Run the LFADS algorithm on an RNN that integrates white noise.

The goal of this tutorial is to learn about LFADS by running the algorithm on a simple data generator, a vanilla recurrent neural network (RNN) that was trained to integrate a white noise input. Running LFADS on this integrator RNN will infer two things:

  1. the underlying hidden state of the integrator RNN
  2. the white noise input to the integrator RNN.

Doing this will exercise the more complex LFADS architecture that is shown in Figure 5 of the LFADS paper. It's pretty important that you have read at least the introduction of the paper; otherwise, you won't understand why we are doing what we are doing.

In this tutorial we do a few things:

  1. Load the integrator RNN data and "spikify" it by treating the hidden units as nonhomogeneous Poisson processes.
  2. Explain a bit of the LFADS architecture and highlight some of the relevant hyperparameters.
  3. Train the LFADS system on the spikified integrator RNN hidden states.
  4. Plot a whole bunch of training plots and LFADS outputs!

If you make it through this tutorial and understand everything in it, it is highly likely you'll be able to run LFADS on your own data.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

 https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

Import the tutorial code.

If you are going to actually run the tutorial, you have to install JAX, download the computation-thru-dynamics GitHub repo, and modify a path.


In [1]:
# Numpy, JAX, Matplotlib and h5py should all be correctly installed and on the python path.
from __future__ import print_function, division, absolute_import

import datetime
import h5py
import jax.numpy as np
from jax import random
from jax.experimental import optimizers
from jax.config import config
#config.update("jax_debug_nans", True) # Useful for finding numerical errors
import matplotlib.pyplot as plt
import numpy as onp  # original CPU-backed NumPy
import scipy.signal
import scipy.stats
import os
import sys
import time

In [2]:
# You must change this to the location of the computation-thru-dynamics directory.
HOME_DIR = '/home/sussillo/' 

sys.path.append(os.path.join(HOME_DIR,'computation-thru-dynamics'))
import lfads_tutorial.lfads as lfads
import lfads_tutorial.plotting as plotting
import lfads_tutorial.utils as utils
from lfads_tutorial.optimize import optimize_lfads, get_kl_warmup_fun

Preliminaries - notes on using JAX

JAX is amazing! It's really, really AMAZING! You program in NumPy/Python, then call grad on your code, and it'll run speedily on GPUs! It does, however, have a few quirks, and it uses a program deployment model you have to know about. The excited reader should definitely read the JAX tutorial if they plan on programming with it.

When using JAX for auto diff, auto batching or compiling, you should always have a two-level mental model in your mind:

  1. at the CPU level, as usual
  2. at the device level, for example, a GPU.

Since JAX compiles your code to the device, it is very efficient but creates this split. Thus, for example, we have two NumPy modules kicking around: 'onp', for 'original NumPy', which runs on the CPU, and 'np', the JAX-modified version, which runs 'on device'. This latter version of NumPy can compute gradients and run your code quickly.

So the model then is: initialize variables, seeds, etc, at the CPU level, and dispatch a JAX based computation to the device. This all happens naturally whenever you call JAX enabled functions.

Thus one of the first things we do is initialize the onp random number generator at the CPU level.


In [3]:
onp_rng = onp.random.RandomState(seed=0) # For CPU-based numpy randomness
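To make the two-level model concrete, here is a tiny example, assuming only that numpy and jax are installed: the onp array lives in host memory, the jax.numpy operations run on the default JAX device (the CPU if no accelerator is present), and converting the result to a Python float copies it back to the host.

```python
import numpy as onp          # original, CPU-backed NumPy
import jax.numpy as np       # JAX's NumPy, dispatched to the default device

x_host = onp.arange(4.0)     # plain NumPy array in host memory
x_dev = np.asarray(x_host)   # copied to the default JAX device
y_dev = np.sum(x_dev ** 2)   # computed on the device: 0 + 1 + 4 + 9
y_host = float(y_dev)        # copied back to the host as a Python float
print(y_host)
```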

Load the data

You must run through the integrator RNN tutorial notebook on your machine. Don't worry! It's much simpler than this tutorial! :)

Point to the correct data file for the integrator RNN. Note that the integrator RNN tutorial notebook creates two files: the parameters file and the data file with examples.


In [4]:
INTEGRATOR_RNN_DATA_FILE = \
     '/tmp/vrnn/pure_int/trained_data_vrnn_pure_int_0.00002_2019-06-19_15:12:45.h5'
lfads_dir = '/tmp/lfads/'       # where to save lfads data and parameters to
rnn_type = 'lfads'
task_type = 'integrator'

In [5]:
# Make directories
data_dir = os.path.join(lfads_dir, 'data/')
output_dir = os.path.join(lfads_dir, 'output/')
figure_dir = os.path.join(output_dir, 'figures/')
for d in (output_dir, figure_dir):
    os.makedirs(d, exist_ok=True)  # no error if the directory already exists

# Load synthetic data
data_dict = utils.read_file(INTEGRATOR_RNN_DATA_FILE)

Plot examples and statistics about the integrator RNN data.


In [6]:
f = plotting.plot_data_pca(data_dict)


Number of data examples:  20480
Number of timesteps:  25
Number of data dimensions:  100

The goal of this tutorial is to infer the hiddens (blue) and the input to the integrator RNN (umm... also blue).


In [7]:
f = plotting.plot_data_example(data_dict['inputs'], 
                               data_dict['hiddens'],
                               data_dict['outputs'], 
                               data_dict['targets'])


Spikify the synthetic data

The integrator RNN tutorial produces the continuous inputs, hidden states and outputs of each example. LFADS is a tool to infer underlying factors in spiking neural data, so we are going to "spikify" the integrator RNN hidden states.

The data were generated by a vanilla RNN (VRNN) with tanh units, so the hidden states lie in $[-1,1]$ and $(\mbox{data}+1) / 2$ maps them to $[0,1]$. We rescale the activations to lie between 0 and 1 here and then convert them to spikes.


In [8]:
data_dt = 1.0/25.0        # define our dt in a physiological range

# If data is normed between 0 and 1, then a 1 yields this many 
# spikes per second. Pushing this downwards makes the problem harder.
max_firing_rate = 80      
train_fraction = 0.9      # Train with 90% of the synthetic data

renormed_fun = lambda x : (x + 1) / 2.0

renormed_data = renormed_fun(data_dict['hiddens'])

# When dimensions are relevant, I use a variable naming scheme like
# name_dim1xdim2x...  so below, here is the synthetic data with 
# 3 dimensions of batch, time and unit, in that order.
data_bxtxn = utils.spikify_data(renormed_data, onp_rng, data_dt,
                                max_firing_rate=max_firing_rate)
nexamples, ntimesteps, data_dim = data_bxtxn.shape

train_data, eval_data = utils.split_data(data_bxtxn,
                                         train_fraction=train_fraction)
eval_data_offset = int(train_fraction * data_bxtxn.shape[0])
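utils.spikify_data is not shown in this notebook. Below is a minimal sketch of what a nonhomogeneous Poisson spikifier could look like, assuming each renormalized activation times max_firing_rate is an instantaneous rate in spikes/s that is held constant within each dt bin, so the count in a bin is Poisson with mean rate × max_firing_rate × dt. spikify_sketch is a hypothetical name, not the tutorial's function.

```python
import numpy as onp

def spikify_sketch(rates_bxtxn, rng, dt, max_firing_rate):
    """Hypothetical stand-in for utils.spikify_data.

    rates_bxtxn: activations already rescaled to [0, 1].
    Returns independent Poisson spike counts per (batch, time, unit) bin.
    """
    lam = onp.asarray(rates_bxtxn) * max_firing_rate * dt  # expected counts per bin
    return rng.poisson(lam)

rng = onp.random.RandomState(seed=0)
toy_rates = onp.full((200, 25, 10), 0.5)   # constant half-max activation
toy_spikes = spikify_sketch(toy_rates, rng, dt=1.0/25.0, max_firing_rate=80)
# Empirical rate in spikes/s; should be close to 0.5 * 80 = 40.
print(toy_spikes.shape, toy_spikes.mean() / (1.0/25.0))
```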

In [9]:
eval_data.shape


Out[9]:
(2048, 25, 100)
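The eval shape follows from simple arithmetic: int(0.9 × 20480) = 18432 training examples, leaving 20480 − 18432 = 2048 for evaluation. The cells below index eval examples as eval_data_offset + i into data_bxtxn, which suggests a contiguous split along the batch axis. Here is a sketch under that assumption; split_data_sketch is hypothetical, and only the batch axis matters, so the toy array uses small trailing dimensions.

```python
import numpy as onp

def split_data_sketch(data_bxtxn, train_fraction):
    """Hypothetical stand-in for utils.split_data: first part trains, the rest evaluates."""
    ntrain = int(train_fraction * data_bxtxn.shape[0])
    return data_bxtxn[:ntrain], data_bxtxn[ntrain:]

toy = onp.zeros((20480, 2, 3))  # same number of examples as data_bxtxn
train, evl = split_data_sketch(toy, train_fraction=0.9)
print(train.shape, evl.shape)
```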

Plot the statistics of the data.


In [10]:
f = plotting.plot_data_stats(data_dict, data_bxtxn, data_dt)


40.04489599609375 spikes/second

Let's study this single example of a single neuron's true firing rate (red) and the spikified version in the blue stem plot.


In [41]:
my_example_bidx = eval_data_offset + 0
my_example_hidx = 0
scale = max_firing_rate * data_dt
my_signal = scale*renormed_data[my_example_bidx, :, my_example_hidx]
my_signal_spikified = data_bxtxn[my_example_bidx, :, my_example_hidx]
plt.plot(my_signal, 'r');
plt.stem(my_signal_spikified);


If you were to increase max_firing_rate toward infinity, the stem plot would track the red line ever more closely. This plot gives you an idea of how challenging the data set is, at least on single trials. We can think about this a little bit. If you were simply to filter the spikes at this low maximum firing rate, the result would definitely not look like the red trace. This means that if any technique were to have
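To put a number on "how challenging": a Poisson count with mean λ has standard deviation √λ, so its relative fluctuation is 1/√λ. At full activation, λ = max_firing_rate × data_dt = 80 × 1/25 = 3.2 expected spikes per bin, so even the strongest bins fluctuate by about 1/√3.2 ≈ 56% of their mean. The short calculation below (relative_poisson_noise is a hypothetical helper) shows that quadrupling the maximum rate only halves this relative noise.

```python
import math

def relative_poisson_noise(max_firing_rate, dt):
    """Std / mean of a Poisson spike count in one bin at full activation."""
    lam = max_firing_rate * dt  # expected spikes per bin
    return 1.0 / math.sqrt(lam)

data_dt = 1.0 / 25.0
for rate in (20, 80, 320, 1280):
    print(rate, round(rate * data_dt, 2), round(relative_poisson_noise(rate, data_dt), 2))
```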


In [42]:
nfilt = 3
my_filtered_spikes = scipy.signal.filtfilt(onp.ones(nfilt)/nfilt, 1, my_signal_spikified)
plt.plot(my_signal, 'r');
plt.plot(my_filtered_spikes);
plt.title("This looks terrible");
plt.legend(('True rate', 'Filtered spikes'));


This forces us to think about ways in which the population can be filtered. The first idea is naturally PCA: perhaps there is a low-d signal subspace that can be found in the high-variance top PCs. Because it pools information across the whole population, it's likely this should do better.


In [13]:
import sklearn.decomposition  # 'import sklearn' alone is not guaranteed to load the decomposition submodule
ncomponents = 100
full_pca = sklearn.decomposition.PCA(ncomponents)
full_pca.fit(onp.reshape(data_bxtxn, [-1, data_dim]))


Out[13]:
PCA(copy=True, iterated_power='auto', n_components=100, random_state=None,
  svd_solver='auto', tol=0.0, whiten=False)

In [14]:
plt.stem(full_pca.explained_variance_)
plt.title('Those top 2 PCs sure look promising!');



In [43]:
ncomponents = 2
pca = sklearn.decomposition.PCA(ncomponents)
pca.fit(onp.reshape(data_bxtxn[0:eval_data_offset,:,:], [-1, data_dim]))
my_example_pca = pca.transform(data_bxtxn[my_example_bidx,:,:])
my_example_ipca = pca.inverse_transform(my_example_pca)

In [44]:
plt.plot(my_signal, 'r')
plt.plot(my_example_ipca[:,my_example_hidx])
plt.legend(('True rate', 'PCA smoothed spikes'))
plt.title('This is a bit better.');
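For intuition about this "spatial filtering", here is a NumPy-only sketch of what fitting PCA and then calling transform followed by inverse_transform amounts to (without whitening): center with the training mean, project onto the top-k principal axes, and map back. pca_denoise is a hypothetical helper, not part of the tutorial code, and the toy data are a rank-2 signal plus Gaussian noise rather than spikes.

```python
import numpy as onp

def pca_denoise(train_nxd, test_nxd, k):
    """Keep only the top-k principal components of test_nxd (hypothetical helper)."""
    mean = train_nxd.mean(axis=0)
    # Rows of vt are the principal axes, ordered by decreasing singular value.
    _, _, vt = onp.linalg.svd(train_nxd - mean, full_matrices=False)
    top = vt[:k]                          # k x d
    scores = (test_nxd - mean) @ top.T    # n x k, analogous to pca.transform
    return scores @ top + mean            # n x d, analogous to pca.inverse_transform

rng = onp.random.RandomState(0)
X = rng.randn(500, 2) @ rng.randn(2, 10) + 3.0   # rank-2 signal plus an offset
noisy = X + 0.5 * rng.randn(*X.shape)
denoised = pca_denoise(noisy, noisy, k=2)
err_noisy = onp.mean((noisy - X) ** 2)
err_denoised = onp.mean((denoised - X) ** 2)
print(err_noisy, err_denoised)
```

Keeping 2 of 10 dimensions discards most of the isotropic noise, which is why the projection helps; it cannot, however, remove noise that falls inside the signal subspace.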


So temporal filtering is not great, and spatial filtering helps only a bit. What to do? The idea LFADS explores is that if you knew the system that generated the data, you would be able to separate signal from noise, the signal being what a system can generate, the noise being the rest.


LFADS - Latent Factor Analysis via Dynamical Systems

Link to the ReadCube version of the LFADS paper (Nature Methods, 2018)

LFADS architecture with inferred inputs

There are 3 variants of the LFADS architecture in the paper:

  1. autonomous LFADS model (no inferred inputs), Fig. 1a
  2. stitched LFADS model for data recorded in different sessions, Fig. 4a
  3. non-autonomous LFADS model (with inferred inputs), Fig. 5a

In this tutorial, we deal with the non-autonomous model, which I believe is conceptually the most interesting, but also the most challenging to understand. This tutorial (and the current code) does NOT handle stitched data. Stitching data isn't conceptually hard, but it's a pain to code. The TensorFlow version of the code handles that if you need it.

Here is the non-autonomous LFADS model architecture. The full description of this model is given in the paper, but briefly, the idea is that the data LFADS will 'denoise' (model) are generated by a nonlinear system with its own autonomous dynamics (we call it the data generator; in this tutorial, the data generator is the integrator RNN) that also receives an input through time. Based on the spiking observations, LFADS will try to pull apart the data into the dynamical system portion and the input portion, hence the term inferred inputs. That is, we are trying to infer what inputs would drive a high-d nonlinear system to generate the data you've recorded. Doing this allows the system to model the dynamics much better for systems that are input-driven. One final detail is that the model assumes that the spikes are Poisson generated from an underlying continuous dynamical system. Of course, this is not true for spiking data from biological neural circuits, but the Poisson approximation seems to be OK.

So the architecture infers a number of quantities of interest:

  1. initial state to the generator (also called initial conditions)
  2. inferred inputs to the generator - e.g. the LFADS component to learn the white noise in the integrator RNN example
  3. dynamical factors - these are like PCs underlying your data
  4. rates - a readout from the factors. The rates are really the most intuitive part; they are analogous to filtering your spiking data.

To begin, let's focus on the autonomous version of the architecture, which excludes the controller RNN. The data are passed through nonlinear, recurrent encoders, which produce an initial state distribution: a per-trial mean and variance from which random vectors encoding that trial are drawn. The initial state of the generator is a vector randomly drawn from this distribution. The generator marches through time and at each time point produces factors and rates, ultimately producing outputs that learn to reproduce your data at the rate level.
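The information flow just described can be sketched in a few lines of NumPy. This is a toy illustration under loud simplifications, not the tutorial's lfads module: the encoder is replaced by fixed, hypothetical posterior parameters (ic_mean, ic_logvar), the generator is a vanilla tanh RNN rather than the GRU LFADS uses, and all weight names are hypothetical. What it does show is the order of operations: sample the initial state with the reparameterization trick, unroll the generator, read out factors linearly, and map the factors to positive Poisson rates with an exponential.

```python
import numpy as onp

def sample_ic(mean, logvar, rng):
    """Reparameterized draw from the per-trial posterior N(mean, exp(logvar))."""
    return mean + onp.exp(0.5 * logvar) * rng.randn(*mean.shape)

def autonomous_rollout(g0, params, ntimesteps):
    """Simplified autonomous generator: no controller, no inferred inputs."""
    W_gg, W_fac, W_rate, b_rate = params
    g = g0
    factors, rates = [], []
    for _ in range(ntimesteps):
        g = onp.tanh(W_gg @ g)             # generator state update (tanh RNN, not a GRU)
        f = W_fac @ g                      # factors: linear readout of the generator
        r = onp.exp(W_rate @ f + b_rate)   # rates: exponential keeps them positive
        factors.append(f)
        rates.append(r)
    return onp.stack(factors), onp.stack(rates)

rng = onp.random.RandomState(0)
gen_dim, factors_dim, data_dim, T = 16, 4, 10, 25
params = (rng.randn(gen_dim, gen_dim) / onp.sqrt(gen_dim),
          rng.randn(factors_dim, gen_dim) / onp.sqrt(gen_dim),
          rng.randn(data_dim, factors_dim) / onp.sqrt(factors_dim),
          onp.zeros(data_dim))
# Hypothetical encoder outputs for one trial.
ic_mean, ic_logvar = onp.zeros(gen_dim), onp.log(0.1) * onp.ones(gen_dim)
g0 = sample_ic(ic_mean, ic_logvar, rng)
factors, rates = autonomous_rollout(g0, params, T)
print(factors.shape, rates.shape)
```

Note that everything after g0 is deterministic: the whole trial must be squeezed through that one vector, which is the bottleneck discussed next.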

From the perspective of information flow, the autonomous version of LFADS has a bottleneck between your data as input to LFADS and the output, which also tries to learn your data. That bottleneck is the initial state of the generator, a potentially very low-bandwidth bottleneck, as a single vector has to encode a high-d time series. Such a system would be adequate for capturing systems that are (approximately) autonomous. For example, motor cortex dynamics during center-out reaches seem extremely well approximated by autonomous dynamics at the sub-second time scale (e.g. Fig. 2). However, if you were to perturb the reach by messing with the cursor the animal was using, e.g. perturbing the cursor location mid-reach, then the motor cortical dynamics of a corrected reach couldn't possibly be autonomous. In other words, some additional input must have come into the motor cortex and updated the system with the information that the cursor had jumped unexpectedly. This is the experimental setting we set up in Fig. 5.

To compensate for such a scenario, we added a controller and inferred inputs to the generator portion of LFADS. In particular, the controller runs in sync with the generator and receives the output of the generator from the last time step (the only "backward" loop in the architecture, aside from using backprop for training with gradient descent). Thus it knows what the generator output. During training, the system learns that there are patterns in the data that cannot be created by the generator autonomously, and so learns to compensate by emitting information from the data, through the encoders, through the controller, to the generator. We call this information an inferred input. In our experimental setup, this worked well on two examples: messing with the cursor of an animal making a reach and inferring oscillations in the local field potential (LFP).
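Extending the same kind of toy sketch, one time step of the non-autonomous model might look like the following. Again this is a hypothetical simplification, not the tutorial's code: tanh RNNs stand in for GRUs, there is no dropout, the encoder outputs are random stand-ins, and every name is made up for illustration. The controller sees this step's encoder output plus the previous factors (the "backward" loop), emits a Gaussian over the inferred input u_t, and a sample of u_t drives the generator; ii_dim = 1 mirrors the single inferred input used in this tutorial.

```python
import numpy as onp

def controlled_step(g_prev, c_prev, f_prev, enc_t, params, rng):
    """One step of a simplified non-autonomous generator (hypothetical)."""
    W_cc, W_ce, W_cf, W_um, W_uv, W_gg, W_gu, W_fac = params
    # Controller: sees this step's encoder output and the previous factors.
    c = onp.tanh(W_cc @ c_prev + W_ce @ enc_t + W_cf @ f_prev)
    # Posterior over the inferred input u_t, then a reparameterized sample.
    u_mean, u_logvar = W_um @ c, W_uv @ c
    u = u_mean + onp.exp(0.5 * u_logvar) * rng.randn(*u_mean.shape)
    # Generator: its own dynamics plus the inferred input.
    g = onp.tanh(W_gg @ g_prev + W_gu @ u)
    f = W_fac @ g  # factors, fed back to the controller at the next step
    return g, c, f, u

rng = onp.random.RandomState(0)
enc_dim, con_dim, ii_dim, gen_dim, factors_dim, T = 6, 8, 1, 16, 4, 25
shapes = [(con_dim, con_dim), (con_dim, enc_dim), (con_dim, factors_dim),
          (ii_dim, con_dim), (ii_dim, con_dim), (gen_dim, gen_dim),
          (gen_dim, ii_dim), (factors_dim, gen_dim)]
params = [0.3 * rng.randn(*s) for s in shapes]
enc_txe = rng.randn(T, enc_dim)  # stand-in for the encoder outputs
g, c, f = rng.randn(gen_dim), onp.zeros(con_dim), onp.zeros(factors_dim)
us = []
for t in range(T):
    g, c, f, u = controlled_step(g, c, f, enc_txe[t], params, rng)
    us.append(u)
us = onp.stack(us)  # T x ii_dim inferred inputs
print(us.shape)
```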

Please note that the inferred input system is extremely powerful, as it provides a leak from your input data to the LFADS output on a per-time-point basis. As such, one has to make sure that the system does not pathologically leak all the information from the data trial through LFADS to generate the data trial. LFADS, like all auto-encoders, is at risk of creating a trivial identity function, $x = f(x)$, rather than finding structure in the data. Thus, we utilize many tricks to avoid this (dropout, KL penalties, and even blocking out the information given to the controller from time step t when decoding time step t).

Hyperparameters


In [17]:
# LFADS Hyper parameters
data_dim = train_data.shape[2]  # input to lfads should have dimensions:
ntimesteps = train_data.shape[1] #   (batch_size x ntimesteps x data_dim)
batch_size = 128      # batch size during optimization

# LFADS architecture - The size of the numbers is rather arbitrary, 
# but relatively small because we know the integrator RNN isn't too high 
# dimensional in its activity.
enc_dim = 128         # encoder dim
con_dim = 128         # controller dim
ii_dim = 1            # inferred input dim, we know there is 1 dim in integrator RNN
gen_dim = 128         # generator dim, should be large enough to generate integrator RNN dynamics
factors_dim = 32      # factors dim, should be large enough to capture most variance of dynamics

# Numerical stability
var_min = 0.001 # Minimal variance any gaussian can become.

# Optimization HPs that percolate into the model
l2reg = 0.00002

Hyperparameters for Priors

As was mentioned above, LFADS is an auto-encoder, and auto-encoders typically encode data through some kind of information bottleneck. The idea is a lot like PCA: if one gets rid of unimportant variation, then perhaps meaningful and interesting structure in the data will become apparent.

More precisely, LFADS is a variational auto-encoder (VAE), which means that the bottleneck is achieved via probabilistic methods. Namely, each trial's initial state is encoded in a per-trial Gaussian distribution called the 'posterior', i.e. the initial state's mean and variance are given by $(\mu(\mathbf{x}), \sigma^2(\mathbf{x}))$, where $\mathbf{x}$ is the data. This is then compared to an uninformative prior $(\mu_p, \sigma^2_p)$, where uninformative means the prior is independent of the data, including that trial. A measure of dissimilarity between distributions, called the KL divergence, is used to force each trial's initial state Gaussian to be as close as possible to a Gaussian that doesn't depend on the trial. This is part of the ELBO (Evidence Lower BOund) that is used to train VAEs.

In summary, one way of explaining VAEs is that they are auto-encoders that attempt to limit the information flow from the input to the output using a bottleneck based on probability distributions, basically forcing the generator to generate your data from white noise. This is doomed to fail if training works, but in the process, it learns a probabilistic generative model of your data.

In this LFADS architecture, there are two posterior distributions, which depend on the data, and two prior distributions, which do not: one pair for the initial state and one pair for the inferred inputs.


In [18]:
# Initial state prior parameters
# the mean is set to zero in the code
ic_prior_var = 0.1 # this is $\sigma^2_p$ in above paragraph
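For two diagonal Gaussians the KL divergence has a closed form, $KL = \frac{1}{2}\sum_i \left[\log\frac{\sigma^2_{p,i}}{\sigma^2_i} + \frac{\sigma^2_i + (\mu_i - \mu_{p,i})^2}{\sigma^2_{p,i}} - 1\right]$. The function below is a generic sketch of that formula (kl_diag_gauss is a hypothetical name, not necessarily how the tutorial code computes it), evaluated against the zero-mean prior with variance ic_prior_var; the 128-dimensional initial state is just for illustration.

```python
import numpy as onp

def kl_diag_gauss(post_mean, post_var, prior_mean, prior_var):
    """KL( N(post_mean, post_var) || N(prior_mean, prior_var) ), summed over dimensions."""
    return 0.5 * onp.sum(onp.log(prior_var / post_var)
                         + (post_var + (post_mean - prior_mean) ** 2) / prior_var
                         - 1.0)

ic_prior_var = 0.1
ic_dim = 128
prior_mean = onp.zeros(ic_dim)                 # the prior mean is zero
prior_var = ic_prior_var * onp.ones(ic_dim)

# A posterior identical to the prior costs nothing...
print(kl_diag_gauss(prior_mean, prior_var, prior_mean, prior_var))
# ...while a confident, trial-specific posterior pays a penalty.
post_mean = 0.3 * onp.ones(ic_dim)
post_var = 0.01 * onp.ones(ic_dim)
print(kl_diag_gauss(post_mean, post_var, prior_mean, prior_var))
```

In the second case each dimension contributes ½·log(0.1/0.01) because the variance and mean terms happen to sum to exactly 1, so the total is 64·log 10.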

Hyperparameters for inferred inputs

The inferred inputs are also codes represented by posterior distributions, but now each time point is a Gaussian, so each inferred input time series is really a Gaussian process. A natural uninformative prior to compare this Gaussian process to is the autoregressive-1 process, or AR-1 process for short:

$s_t = c + \phi s_{t-1} + \epsilon_t, \mbox{ with } \epsilon_t \sim \mathcal{N}(0, \sigma^2_n)$

with $c$ a constant offset (the process mean is $c/(1-\phi)$), $\phi$ giving the dependence of the process state at time $t$ on the state at time $t-1$, and $\epsilon_t$ the noise with variance $\sigma^2_n$. In LFADS, $c$ is always $0$, so the process mean is $0$ as well.

So if you have 4 inferred inputs, then you have 4 AR-1 process priors. Using an AR-1 process as the prior on sequences allows us to introduce another useful concept: the auto-correlation of each sequence. The auto-correlation is the correlation between values of the process at different time points. We are interested in auto-correlation because we may want to penalize very jagged or very smooth inferred inputs on a task-by-task basis, as well as for other technical reasons. As it turns out, the input to the integrator RNN in this tutorial is uncorrelated white noise, so this concept is not too important here, but in general it may be very important.
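A short sketch of sampling from the AR-1 prior above (with $c = 0$) and checking how $\phi$ sets the auto-correlation: for a stationary AR-1 process the lag-$k$ auto-correlation is $\phi^k$ and the stationary variance is $\sigma^2_n / (1 - \phi^2)$. How the tutorial code maps ar_autocorrelation_tau onto $\phi$ is not shown here, so this sketch works with $\phi$ directly; sample_ar1 and lag_autocorr are hypothetical helpers.

```python
import numpy as onp

def sample_ar1(phi, noise_var, ntimesteps, nsamples, rng):
    """Draw AR-1 sequences s_t = phi * s_{t-1} + eps_t, eps_t ~ N(0, noise_var)."""
    stationary_var = noise_var / (1.0 - phi ** 2)   # variance the process settles to
    s = onp.zeros((nsamples, ntimesteps))
    s[:, 0] = onp.sqrt(stationary_var) * rng.randn(nsamples)  # start in the stationary state
    for t in range(1, ntimesteps):
        s[:, t] = phi * s[:, t - 1] + onp.sqrt(noise_var) * rng.randn(nsamples)
    return s

def lag_autocorr(s, lag):
    """Empirical correlation between s_t and s_{t+lag}, pooled over samples and time."""
    a, b = s[:, :-lag].ravel(), s[:, lag:].ravel()
    return onp.corrcoef(a, b)[0, 1]

rng = onp.random.RandomState(0)
results = {}
for phi in (0.0, 0.5, 0.9):
    s = sample_ar1(phi, noise_var=0.1, ntimesteps=25, nsamples=4000, rng=rng)
    results[phi] = (lag_autocorr(s, 1), lag_autocorr(s, 2))  # expect about phi, phi**2
    print(phi, [round(r, 2) for r in results[phi]])
```

With $\phi = 0$ the prior is white noise, which matches the integrator RNN's input; larger $\phi$ favors smoother inferred inputs.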

So, just as for the initial states, where we introduced multivariate Gaussian distributions (the posteriors) for each data trial and an uninformative prior to which the per-trial posteriors are compared, we do the same thing with inferred inputs, now using the KL divergence to compare the posterior distributions over inferred input sequences to uninformative AR-1 priors. In this way, we aim to limit how informative the inferred inputs are by introducing a bottleneck between the encoder and the generator.


In [19]:
# Inferred input autoregressive prior parameters
# Again, these hyper parameters are set "in the ballpark" but otherwise
# pretty randomly.
ar_mean = 0.0                 # process mean
ar_autocorrelation_tau = 1.0  # seconds, how correlated each time point is, related to $\phi$ above.
ar_noise_variance = 0.1       # noise variance

In [20]:
lfads_hps = {'data_dim' : data_dim, 'ntimesteps' : ntimesteps,
             'enc_dim' : enc_dim, 'con_dim' : con_dim, 'var_min' : var_min,
             'ic_prior_var' : ic_prior_var, 'ar_mean' : ar_mean,
             'ar_autocorrelation_tau' : ar_autocorrelation_tau,
             'ar_noise_variance' : ar_noise_variance,
             'ii_dim' : ii_dim, 'gen_dim' : gen_dim,
             'factors_dim' : factors_dim,
             'l2reg' : l2reg,
             'batch_size' : batch_size}

LFADS Optimization hyperparameters


In [21]:
num_batches = 20000         # how many batches do we train
print_every = 100            # give information every so often

# Learning rate HPs
step_size = 0.05            # initial learning rate
decay_factor = 0.9999      # learning rate decay param
decay_steps = 1             # learning rate decay param

# Regularization HPs
keep_rate = 0.98            # dropout keep rate during training

# Numerical stability HPs
max_grad_norm = 10.0        # gradient clipping above this value

Warming up the KL penalties

Training a VAE maximizes the ELBO or, equivalently, minimizes the negative ELBO, which is

$L(\theta) = \mathop{\mathbb{E}}_x \left(-\mathop{\mathbb{E}}_{z \sim q_\theta(z|x)} \log p_\theta(x|z) + KL(q_\theta(z|x) \;\;|| \;\;p(z))\right)$

  • $p_\theta(x|z)$ - the reconstruction given the initial state and inferred inputs distributions (collectively denoted $z$ here)

  • $q_\theta(z|x)$ - represents the latent variable posterior distributions (the data encoders that ultimately yield the initial state and inferred input codes).

  • $p(z)$ - the prior that does not know about the data

where $\theta$ are all the trainable parameters. This is an expectation over all your data, $x$, of the quality of the data generation $p_\theta(x|z)$, plus the KL divergence penalty mentioned above that compares the distributions for the initial state and inferred inputs to uninformative priors.
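To connect the formula to code, here is a minimal per-trial sketch of the negative ELBO, assuming independent Poisson observations (as LFADS does), a single posterior sample z that has already been pushed through the generator to give rates (expected counts per bin), and KL terms already computed for the initial state and the inferred inputs. poisson_nll and negative_elbo are hypothetical helpers; the kl_weight argument is where a warm-up weight would enter.

```python
import math
import numpy as onp

def poisson_nll(counts, rates):
    """-log p(counts | rates) for independent Poisson counts, summed over all bins."""
    log_fact = onp.vectorize(math.lgamma)(counts + 1.0)   # log(counts!)
    return onp.sum(rates - counts * onp.log(rates) + log_fact)

def negative_elbo(counts, rates, kl_ic, kl_ii, kl_weight=1.0):
    """Reconstruction term plus (optionally down-weighted) KL penalties."""
    return poisson_nll(counts, rates) + kl_weight * (kl_ic + kl_ii)

counts = onp.array([[0., 1.], [2., 0.]])   # toy spike counts (time x neurons)
rates = onp.ones((2, 2))                   # toy predicted rates
print(poisson_nll(counts, rates), negative_elbo(counts, rates, 3.0, 5.0, kl_weight=0.5))
```

With all rates equal to 1 the log-rate term vanishes, so the NLL is just 4 (the sum of rates) plus log 2! = log 2.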

All the hacks in hacksville: it turns out that the KL term can be a lot easier to optimize initially than learning how to reconstruct your data. This results in a pathological stoppage of training where the KL goes to nearly zero and training is broken from there on out (as you cannot represent a given trial from uninformative priors). One way out of this is to warm up the KL penalty, starting it off with a weight of (nearly) 0 and then slowly building to 1, giving the reconstruction a chance to train a bit without the KL penalty messing things up.


In [22]:
# The fact that the start and end values are required to be floats is something I need to fix.
kl_warmup_start = 500.0 # batch number to start KL warm-up, explicitly float
kl_warmup_end = 1000.0  # batch number to be finished with KL warm-up, explicitly float
kl_min = 0.01 # Minimum KL weight, non-zero so the KL terms can't grow unchecked before the warm-up kicks in.

Note: there is currently a HUGE amount of debate about the correct value of the KL penalty weight. kl_max = 1 is what creates a lower bound on the (marginal) log likelihood of the data, but folks argue it could be higher or lower than 1. Myself, I have never played around with this HP, but I have the idea that LFADS may benefit from values < 1, as LFADS is not really being used for random sampling from the distribution of spiking data.

See "$\beta$-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework".

See "Fixing a Broken ELBO" for why you might choose a particular KL maximum value. I found this article pretty clarifying.


In [23]:
kl_max = 1.0
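get_kl_warmup_fun comes from the tutorial code, and its body isn't shown here. Assuming it ramps the KL weight linearly from kl_min to kl_max between kl_warmup_start and kl_warmup_end (the plot in the next cell is one way to check this), a minimal sketch looks like this; kl_warmup_sketch is a hypothetical name.

```python
def kl_warmup_sketch(batch_idx, kl_warmup_start=500.0, kl_warmup_end=1000.0,
                     kl_min=0.01, kl_max=1.0):
    """Hypothetical linear KL warm-up: kl_min before start, kl_max after end."""
    if batch_idx <= kl_warmup_start:
        return kl_min
    if batch_idx >= kl_warmup_end:
        return kl_max
    frac = (batch_idx - kl_warmup_start) / (kl_warmup_end - kl_warmup_start)
    return kl_min + frac * (kl_max - kl_min)

for b in (0, 500, 600, 750, 1000, 5000):
    print(b, round(kl_warmup_sketch(b), 3))
```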

In [24]:
lfads_opt_hps = {'num_batches' : num_batches, 'step_size' : step_size,
                 'decay_steps' : decay_steps, 'decay_factor' : decay_factor,
                 'kl_min' : kl_min, 'kl_max' : kl_max, 'kl_warmup_start' : kl_warmup_start,
                 'kl_warmup_end' : kl_warmup_end, 'keep_rate' : keep_rate,
                 'max_grad_norm' : max_grad_norm, 'print_every' : print_every,
                 'adam_b1' : 0.9, 'adam_b2' : 0.999, 'adam_eps' : 1e-1}

assert num_batches >= print_every and num_batches % print_every == 0

In [25]:
# Plot the warmup function and the learning rate decay function.
plt.figure(figsize=(16,4))
plt.subplot(121)
x = onp.arange(0, num_batches, print_every)
kl_warmup_fun = get_kl_warmup_fun(lfads_opt_hps)
plt.plot(x, [kl_warmup_fun(i) for i in onp.arange(1,lfads_opt_hps['num_batches'], print_every)]);
plt.title('KL warmup function')
plt.xlabel('Training batch');

plt.subplot(122)
decay_fun = optimizers.exponential_decay(lfads_opt_hps['step_size'],
                                         lfads_opt_hps['decay_steps'],
                                         lfads_opt_hps['decay_factor'])
plt.plot(x, [decay_fun(i) for i in range(1, lfads_opt_hps['num_batches'], print_every)]);
plt.title('learning rate function')
plt.xlabel('Training batch');
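optimizers.exponential_decay(step_size, decay_steps, decay_rate) returns the schedule i ↦ step_size · decay_rate^(i / decay_steps) (in recent JAX releases this module lives at jax.example_libraries.optimizers). With step_size = 0.05, decay_steps = 1 and decay_factor = 0.9999, that gives 0.05 × 0.9999^100 ≈ 0.04950 and 0.05 × 0.9999^1000 ≈ 0.04524, matching the "Step size" printed after batches 100 and 1000 in the training log below, and roughly 0.05 · e^(-2) ≈ 0.0068 after all 20000 batches. A plain-Python version of the same formula (exponential_decay_sketch is a hypothetical name):

```python
def exponential_decay_sketch(step_size, decay_steps, decay_rate):
    """Same formula as JAX's optimizers.exponential_decay, in plain Python."""
    return lambda i: step_size * decay_rate ** (i / decay_steps)

schedule = exponential_decay_sketch(0.05, 1, 0.9999)
for i in (100, 200, 1000, 20000):
    print(i, round(schedule(i), 5))
```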


Train the LFADS model

Note that JAX uses its own setup to handle randomness and seeding the pseudo-random number generators; you can read about it in the JAX documentation. If you want to modify the LFADS tutorial, you NEED to understand this. Otherwise, it's not so big a deal if you are just messing around with LFADS hyperparameters or applying the tutorial to new data.
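A minimal sketch of the JAX pattern, using jax.random.PRNGKey, jax.random.split and jax.random.normal: a key is an explicit value rather than hidden global state, the same key always produces the same numbers, and fresh randomness comes from splitting a key. This is why the cells below create keys from seeds drawn with onp.

```python
from jax import random

key = random.PRNGKey(0)               # explicit PRNG state
key, subkey = random.split(key)       # derive a fresh key instead of reusing one
x1 = random.normal(subkey, (3,))
x2 = random.normal(subkey, (3,))      # same key -> identical numbers
key, subkey2 = random.split(key)
x3 = random.normal(subkey2, (3,))     # new subkey -> new numbers
print(bool((x1 == x2).all()), bool((x1 == x3).all()))
```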


In [26]:
# Initialize parameters for LFADS
key = random.PRNGKey(onp.random.randint(0, utils.MAX_SEED_INT))
init_params = lfads.lfads_params(key, lfads_hps)

Note that the first loop could take a few minutes to run: the LFADS model is unrolled, so the JIT (just-in-time) compilation is slow, and it happens "just in time", i.e. during the first training loop iteration. On my computer, the JIT compilation takes a few minutes.

You'll see the loss go up when the KL warmup starts turning on.


In [27]:
key = random.PRNGKey(onp.random.randint(0, utils.MAX_SEED_INT))
trained_params, opt_details = \
    optimize_lfads(key, init_params, lfads_hps, lfads_opt_hps,
                   train_data, eval_data)


Batches 1-100 in 57.34 sec, Step size: 0.04950
    Training losses 2119 = NLL 2116 + KL IC 178,2 + KL II 75,1 + L2 0.09
        Eval losses 2117 = NLL 2115 + KL IC 175,2 + KL II 75,1 + L2 0.09
Batches 101-200 in 5.25 sec, Step size: 0.04901
    Training losses 2122 = NLL 2119 + KL IC 196,2 + KL II 98,1 + L2 0.09
        Eval losses 2121 = NLL 2118 + KL IC 194,2 + KL II 98,1 + L2 0.09
Batches 201-300 in 5.25 sec, Step size: 0.04852
    Training losses 2147 = NLL 2144 + KL IC 203,2 + KL II 101,1 + L2 0.10
        Eval losses 2144 = NLL 2141 + KL IC 200,2 + KL II 101,1 + L2 0.10
Batches 301-400 in 5.25 sec, Step size: 0.04804
    Training losses 2172 = NLL 2169 + KL IC 203,2 + KL II 100,1 + L2 0.10
        Eval losses 2166 = NLL 2163 + KL IC 201,2 + KL II 100,1 + L2 0.10
Batches 401-500 in 5.22 sec, Step size: 0.04756
    Training losses 2171 = NLL 2168 + KL IC 204,2 + KL II 97,1 + L2 0.11
        Eval losses 2169 = NLL 2166 + KL IC 204,2 + KL II 97,1 + L2 0.11
Batches 501-600 in 5.22 sec, Step size: 0.04709
    Training losses 2157 = NLL 2156 + KL IC 63,1 + KL II 60,1 + L2 0.11
        Eval losses 2166 = NLL 2164 + KL IC 62,1 + KL II 57,1 + L2 0.11
Batches 601-700 in 5.26 sec, Step size: 0.04662
    Training losses 2158 = NLL 2145 + KL IC 25,5 + KL II 36,8 + L2 0.11
        Eval losses 2156 = NLL 2143 + KL IC 25,5 + KL II 36,8 + L2 0.11
Batches 701-800 in 5.21 sec, Step size: 0.04616
    Training losses 2161 = NLL 2147 + KL IC 12,5 + KL II 21,9 + L2 0.11
        Eval losses 2171 = NLL 2158 + KL IC 11,5 + KL II 21,9 + L2 0.11
Batches 801-900 in 5.26 sec, Step size: 0.04570
    Training losses 2173 = NLL 2158 + KL IC 8,5 + KL II 16,9 + L2 0.11
        Eval losses 2170 = NLL 2155 + KL IC 8,5 + KL II 16,10 + L2 0.11
Batches 901-1000 in 5.23 sec, Step size: 0.04524
    Training losses 2186 = NLL 2168 + KL IC 7,5 + KL II 16,13 + L2 0.12
        Eval losses 2186 = NLL 2167 + KL IC 6,5 + KL II 17,14 + L2 0.12
Batches 1001-1100 in 5.16 sec, Step size: 0.04479
    Training losses 2182 = NLL 2163 + KL IC 5,5 + KL II 14,14 + L2 0.12
        Eval losses 2181 = NLL 2163 + KL IC 5,5 + KL II 13,13 + L2 0.12
Batches 1101-1200 in 5.17 sec, Step size: 0.04435
    Training losses 2177 = NLL 2159 + KL IC 4,4 + KL II 14,14 + L2 0.12
        Eval losses 2180 = NLL 2162 + KL IC 4,4 + KL II 15,15 + L2 0.12
Batches 1201-1300 in 5.23 sec, Step size: 0.04390
    Training losses 2185 = NLL 2166 + KL IC 3,3 + KL II 15,15 + L2 0.12
        Eval losses 2186 = NLL 2168 + KL IC 3,3 + KL II 15,15 + L2 0.12
Batches 1301-1400 in 5.19 sec, Step size: 0.04347
    Training losses 2185 = NLL 2168 + KL IC 3,3 + KL II 15,15 + L2 0.13
        Eval losses 2187 = NLL 2168 + KL IC 3,3 + KL II 15,15 + L2 0.13
Batches 1401-1500 in 5.21 sec, Step size: 0.04304
    Training losses 2186 = NLL 2163 + KL IC 4,4 + KL II 20,20 + L2 0.13
        Eval losses 2190 = NLL 2167 + KL IC 4,4 + KL II 20,20 + L2 0.13
Batches 1501-1600 in 5.25 sec, Step size: 0.04261
    Training losses 2173 = NLL 2152 + KL IC 3,3 + KL II 19,19 + L2 0.13
        Eval losses 2170 = NLL 2149 + KL IC 2,2 + KL II 19,19 + L2 0.13
Batches 1601-1700 in 5.24 sec, Step size: 0.04218
    Training losses 2158 = NLL 2141 + KL IC 2,2 + KL II 15,15 + L2 0.13
        Eval losses 2164 = NLL 2144 + KL IC 2,2 + KL II 18,18 + L2 0.13
Batches 1701-1800 in 5.25 sec, Step size: 0.04176
    Training losses 2161 = NLL 2141 + KL IC 2,2 + KL II 18,18 + L2 0.14
        Eval losses 2154 = NLL 2136 + KL IC 2,2 + KL II 16,16 + L2 0.14
Batches 1801-1900 in 5.21 sec, Step size: 0.04135
    Training losses 2157 = NLL 2138 + KL IC 2,2 + KL II 17,17 + L2 0.14
        Eval losses 2158 = NLL 2139 + KL IC 2,2 + KL II 17,17 + L2 0.14
Batches 1901-2000 in 5.21 sec, Step size: 0.04094
    Training losses 2167 = NLL 2150 + KL IC 2,2 + KL II 15,15 + L2 0.14
        Eval losses 2168 = NLL 2152 + KL IC 2,2 + KL II 15,15 + L2 0.14
Batches 2001-2100 in 5.20 sec, Step size: 0.04053
    Training losses 2136 = NLL 2120 + KL IC 1,1 + KL II 15,15 + L2 0.14
        Eval losses 2136 = NLL 2119 + KL IC 1,1 + KL II 15,15 + L2 0.14
Batches 2101-2200 in 5.19 sec, Step size: 0.04013
    Training losses 2145 = NLL 2125 + KL IC 1,1 + KL II 19,19 + L2 0.14
        Eval losses 2151 = NLL 2132 + KL IC 1,1 + KL II 17,17 + L2 0.14
Batches 2201-2300 in 5.21 sec, Step size: 0.03973
    Training losses 2151 = NLL 2133 + KL IC 1,1 + KL II 17,17 + L2 0.14
        Eval losses 2139 = NLL 2122 + KL IC 1,1 + KL II 16,16 + L2 0.14
Batches 2301-2400 in 5.25 sec, Step size: 0.03933
    Training losses 2139 = NLL 2123 + KL IC 1,1 + KL II 15,15 + L2 0.15
        Eval losses 2144 = NLL 2127 + KL IC 1,1 + KL II 15,15 + L2 0.15
Batches 2401-2500 in 5.04 sec, Step size: 0.03894
    Training losses 2148 = NLL 2128 + KL IC 1,1 + KL II 19,19 + L2 0.15
        Eval losses 2158 = NLL 2138 + KL IC 1,1 + KL II 19,19 + L2 0.15
Batches 2501-2600 in 4.90 sec, Step size: 0.03855
    Training losses 2141 = NLL 2122 + KL IC 1,1 + KL II 18,18 + L2 0.15
        Eval losses 2143 = NLL 2124 + KL IC 1,1 + KL II 18,18 + L2 0.15
Batches 2601-2700 in 5.17 sec, Step size: 0.03817
    Training losses 2141 = NLL 2122 + KL IC 1,1 + KL II 18,18 + L2 0.15
        Eval losses 2134 = NLL 2117 + KL IC 1,1 + KL II 16,16 + L2 0.15
Batches 2701-2800 in 4.95 sec, Step size: 0.03779
    Training losses 2131 = NLL 2113 + KL IC 1,1 + KL II 17,17 + L2 0.15
        Eval losses 2143 = NLL 2124 + KL IC 1,1 + KL II 17,17 + L2 0.15
Batches 2801-2900 in 4.96 sec, Step size: 0.03741
    Training losses 2140 = NLL 2117 + KL IC 1,1 + KL II 22,22 + L2 0.15
        Eval losses 2148 = NLL 2123 + KL IC 1,1 + KL II 24,24 + L2 0.15
Batches 2901-3000 in 4.90 sec, Step size: 0.03704
    Training losses 2136 = NLL 2117 + KL IC 1,1 + KL II 19,19 + L2 0.15
        Eval losses 2132 = NLL 2113 + KL IC 1,1 + KL II 18,18 + L2 0.15
Batches 3001-3100 in 4.96 sec, Step size: 0.03667
    Training losses 2132 = NLL 2115 + KL IC 1,1 + KL II 16,16 + L2 0.15
        Eval losses 2131 = NLL 2114 + KL IC 1,1 + KL II 16,16 + L2 0.15
Batches 3101-3200 in 4.94 sec, Step size: 0.03631
    Training losses 2139 = NLL 2121 + KL IC 1,1 + KL II 17,17 + L2 0.15
        Eval losses 2136 = NLL 2120 + KL IC 1,1 + KL II 16,16 + L2 0.15
Batches 3201-3300 in 4.93 sec, Step size: 0.03595
    Training losses 2130 = NLL 2115 + KL IC 1,1 + KL II 14,14 + L2 0.16
        Eval losses 2126 = NLL 2112 + KL IC 1,1 + KL II 13,13 + L2 0.16
Batches 3301-3400 in 4.97 sec, Step size: 0.03559
    Training losses 2126 = NLL 2112 + KL IC 1,1 + KL II 13,13 + L2 0.16
        Eval losses 2130 = NLL 2115 + KL IC 1,1 + KL II 14,14 + L2 0.16
Batches 3401-3500 in 4.97 sec, Step size: 0.03523
    Training losses 2126 = NLL 2111 + KL IC 1,1 + KL II 14,14 + L2 0.16
        Eval losses 2129 = NLL 2114 + KL IC 1,1 + KL II 14,14 + L2 0.16
Batches 3501-3600 in 4.96 sec, Step size: 0.03488
    Training losses 2129 = NLL 2115 + KL IC 1,1 + KL II 14,14 + L2 0.16
        Eval losses 2134 = NLL 2120 + KL IC 1,1 + KL II 14,14 + L2 0.16
Batches 3601-3700 in 4.97 sec, Step size: 0.03454
    Training losses 2123 = NLL 2110 + KL IC 1,1 + KL II 13,13 + L2 0.16
        Eval losses 2127 = NLL 2114 + KL IC 1,1 + KL II 13,13 + L2 0.16
Batches 3701-3800 in 4.98 sec, Step size: 0.03419
    Training losses 2130 = NLL 2114 + KL IC 0,0 + KL II 15,15 + L2 0.16
        Eval losses 2130 = NLL 2112 + KL IC 0,0 + KL II 17,17 + L2 0.16
Batches 3801-3900 in 4.95 sec, Step size: 0.03385
    Training losses 2135 = NLL 2117 + KL IC 1,1 + KL II 17,17 + L2 0.16
        Eval losses 2129 = NLL 2114 + KL IC 0,0 + KL II 14,14 + L2 0.16
Batches 3901-4000 in 4.93 sec, Step size: 0.03352
    Training losses 2131 = NLL 2117 + KL IC 0,0 + KL II 13,13 + L2 0.16
        Eval losses 2129 = NLL 2115 + KL IC 0,0 + KL II 14,14 + L2 0.16
Batches 4001-4100 in 4.97 sec, Step size: 0.03318
    Training losses 2133 = NLL 2119 + KL IC 0,0 + KL II 14,14 + L2 0.16
        Eval losses 2135 = NLL 2120 + KL IC 0,0 + KL II 14,14 + L2 0.16
Batches 4101-4200 in 4.94 sec, Step size: 0.03285
    Training losses 2124 = NLL 2108 + KL IC 0,0 + KL II 15,15 + L2 0.16
        Eval losses 2128 = NLL 2112 + KL IC 0,0 + KL II 15,15 + L2 0.16
Batches 4201-4300 in 4.93 sec, Step size: 0.03252
    Training losses 2127 = NLL 2115 + KL IC 0,0 + KL II 12,12 + L2 0.16
        Eval losses 2117 = NLL 2105 + KL IC 0,0 + KL II 12,12 + L2 0.16
Batches 4301-4400 in 4.99 sec, Step size: 0.03220
    Training losses 2125 = NLL 2109 + KL IC 0,0 + KL II 15,15 + L2 0.16
        Eval losses 2125 = NLL 2109 + KL IC 0,0 + KL II 15,15 + L2 0.16
Batches 4401-4500 in 4.94 sec, Step size: 0.03188
    Training losses 2127 = NLL 2114 + KL IC 0,0 + KL II 13,13 + L2 0.17
        Eval losses 2118 = NLL 2104 + KL IC 0,0 + KL II 13,13 + L2 0.17
Batches 4501-4600 in 5.01 sec, Step size: 0.03156
    Training losses 2124 = NLL 2107 + KL IC 0,0 + KL II 16,16 + L2 0.17
        Eval losses 2116 = NLL 2100 + KL IC 0,0 + KL II 16,16 + L2 0.17
Batches 4601-4700 in 4.96 sec, Step size: 0.03125
    Training losses 2124 = NLL 2109 + KL IC 0,0 + KL II 14,14 + L2 0.17
        Eval losses 2119 = NLL 2104 + KL IC 0,0 + KL II 15,15 + L2 0.17
Batches 4701-4800 in 4.96 sec, Step size: 0.03094
    Training losses 2124 = NLL 2109 + KL IC 0,0 + KL II 15,15 + L2 0.17
        Eval losses 2124 = NLL 2109 + KL IC 0,0 + KL II 14,14 + L2 0.17
Batches 4801-4900 in 4.97 sec, Step size: 0.03063
    Training losses 2125 = NLL 2111 + KL IC 0,0 + KL II 14,14 + L2 0.17
        Eval losses 2120 = NLL 2106 + KL IC 0,0 + KL II 14,14 + L2 0.17
Batches 4901-5000 in 4.93 sec, Step size: 0.03033
    Training losses 2117 = NLL 2104 + KL IC 0,0 + KL II 13,13 + L2 0.17
        Eval losses 2120 = NLL 2106 + KL IC 0,0 + KL II 14,14 + L2 0.17
Batches 5001-5100 in 4.99 sec, Step size: 0.03002
    Training losses 2120 = NLL 2106 + KL IC 0,0 + KL II 13,13 + L2 0.17
        Eval losses 2122 = NLL 2110 + KL IC 0,0 + KL II 12,12 + L2 0.17
Batches 5101-5200 in 4.94 sec, Step size: 0.02973
    Training losses 2115 = NLL 2101 + KL IC 0,0 + KL II 14,14 + L2 0.17
        Eval losses 2122 = NLL 2109 + KL IC 0,0 + KL II 13,13 + L2 0.17
Batches 5201-5300 in 4.98 sec, Step size: 0.02943
    Training losses 2124 = NLL 2110 + KL IC 0,0 + KL II 14,14 + L2 0.17
        Eval losses 2121 = NLL 2107 + KL IC 0,0 + KL II 13,13 + L2 0.17
Batches 5301-5400 in 4.96 sec, Step size: 0.02914
    Training losses 2116 = NLL 2104 + KL IC 0,0 + KL II 12,12 + L2 0.17
        Eval losses 2115 = NLL 2102 + KL IC 0,0 + KL II 13,13 + L2 0.17
Batches 5401-5500 in 4.92 sec, Step size: 0.02885
    Training losses 2122 = NLL 2108 + KL IC 0,0 + KL II 14,14 + L2 0.17
        Eval losses 2117 = NLL 2103 + KL IC 0,0 + KL II 14,14 + L2 0.17
Batches 5501-5600 in 4.96 sec, Step size: 0.02856
    Training losses 2115 = NLL 2102 + KL IC 0,0 + KL II 13,13 + L2 0.17
        Eval losses 2117 = NLL 2103 + KL IC 0,0 + KL II 13,13 + L2 0.17
Batches 5601-5700 in 4.99 sec, Step size: 0.02828
    Training losses 2123 = NLL 2109 + KL IC 0,0 + KL II 13,13 + L2 0.17
        Eval losses 2123 = NLL 2108 + KL IC 0,0 + KL II 14,14 + L2 0.17
Batches 5701-5800 in 4.97 sec, Step size: 0.02799
    Training losses 2118 = NLL 2103 + KL IC 0,0 + KL II 14,14 + L2 0.17
        Eval losses 2119 = NLL 2104 + KL IC 0,0 + KL II 14,14 + L2 0.17
Batches 5801-5900 in 4.95 sec, Step size: 0.02772
    Training losses 2119 = NLL 2105 + KL IC 0,0 + KL II 14,14 + L2 0.17
        Eval losses 2119 = NLL 2106 + KL IC 0,0 + KL II 13,13 + L2 0.17
Batches 5901-6000 in 4.96 sec, Step size: 0.02744
    Training losses 2120 = NLL 2105 + KL IC 0,0 + KL II 15,15 + L2 0.17
        Eval losses 2113 = NLL 2099 + KL IC 0,0 + KL II 14,14 + L2 0.17
Batches 6001-6100 in 4.96 sec, Step size: 0.02717
    Training losses 2122 = NLL 2107 + KL IC 0,0 + KL II 14,14 + L2 0.17
        Eval losses 2111 = NLL 2096 + KL IC 0,0 + KL II 15,15 + L2 0.17
Batches 6101-6200 in 4.96 sec, Step size: 0.02690
    Training losses 2119 = NLL 2103 + KL IC 0,0 + KL II 16,16 + L2 0.17
        Eval losses 2120 = NLL 2106 + KL IC 0,0 + KL II 14,14 + L2 0.17
Batches 6201-6300 in 4.96 sec, Step size: 0.02663
    Training losses 2120 = NLL 2105 + KL IC 0,0 + KL II 14,14 + L2 0.17
        Eval losses 2118 = NLL 2104 + KL IC 0,0 + KL II 14,14 + L2 0.17
Batches 6301-6400 in 4.96 sec, Step size: 0.02636
    Training losses 2126 = NLL 2110 + KL IC 0,0 + KL II 16,16 + L2 0.17
        Eval losses 2126 = NLL 2110 + KL IC 0,0 + KL II 16,16 + L2 0.17
Batches 6401-6500 in 4.98 sec, Step size: 0.02610
    Training losses 2117 = NLL 2102 + KL IC 0,0 + KL II 14,14 + L2 0.17
        Eval losses 2119 = NLL 2104 + KL IC 0,0 + KL II 14,14 + L2 0.17
Batches 6501-6600 in 4.95 sec, Step size: 0.02584
    Training losses 2117 = NLL 2101 + KL IC 0,0 + KL II 15,15 + L2 0.17
        Eval losses 2119 = NLL 2104 + KL IC 0,0 + KL II 14,14 + L2 0.17
Batches 6601-6700 in 4.95 sec, Step size: 0.02558
    Training losses 2119 = NLL 2103 + KL IC 0,0 + KL II 15,15 + L2 0.17
        Eval losses 2122 = NLL 2106 + KL IC 0,0 + KL II 15,15 + L2 0.17
Batches 6701-6800 in 4.95 sec, Step size: 0.02533
    Training losses 2116 = NLL 2102 + KL IC 0,0 + KL II 13,13 + L2 0.17
        Eval losses 2113 = NLL 2100 + KL IC 0,0 + KL II 13,13 + L2 0.17
Batches 6801-6900 in 4.97 sec, Step size: 0.02508
    Training losses 2118 = NLL 2105 + KL IC 0,0 + KL II 13,13 + L2 0.17
        Eval losses 2119 = NLL 2105 + KL IC 0,0 + KL II 13,13 + L2 0.17
Batches 6901-7000 in 4.98 sec, Step size: 0.02483
    Training losses 2115 = NLL 2101 + KL IC 0,0 + KL II 14,14 + L2 0.17
        Eval losses 2119 = NLL 2105 + KL IC 0,0 + KL II 13,13 + L2 0.17
Batches 7001-7100 in 4.94 sec, Step size: 0.02458
    Training losses 2113 = NLL 2099 + KL IC 0,0 + KL II 14,14 + L2 0.17
        Eval losses 2115 = NLL 2102 + KL IC 0,0 + KL II 13,13 + L2 0.17
Batches 7101-7200 in 4.97 sec, Step size: 0.02434
    Training losses 2106 = NLL 2093 + KL IC 0,0 + KL II 13,13 + L2 0.17
        Eval losses 2114 = NLL 2100 + KL IC 0,0 + KL II 14,14 + L2 0.17
Batches 7201-7300 in 4.97 sec, Step size: 0.02409
    Training losses 2111 = NLL 2096 + KL IC 0,0 + KL II 15,15 + L2 0.18
        Eval losses 2115 = NLL 2100 + KL IC 0,0 + KL II 14,14 + L2 0.18
Batches 7301-7400 in 5.00 sec, Step size: 0.02385
    Training losses 2118 = NLL 2101 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2119 = NLL 2102 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 7401-7500 in 4.96 sec, Step size: 0.02362
    Training losses 2114 = NLL 2100 + KL IC 0,0 + KL II 14,14 + L2 0.18
        Eval losses 2115 = NLL 2101 + KL IC 0,0 + KL II 13,13 + L2 0.18
Batches 7501-7600 in 4.97 sec, Step size: 0.02338
    Training losses 2114 = NLL 2102 + KL IC 0,0 + KL II 12,12 + L2 0.18
        Eval losses 2116 = NLL 2103 + KL IC 0,0 + KL II 13,13 + L2 0.18
Batches 7601-7700 in 4.97 sec, Step size: 0.02315
    Training losses 2103 = NLL 2091 + KL IC 0,0 + KL II 12,12 + L2 0.18
        Eval losses 2107 = NLL 2094 + KL IC 0,0 + KL II 12,12 + L2 0.18
Batches 7701-7800 in 4.99 sec, Step size: 0.02292
    Training losses 2109 = NLL 2096 + KL IC 0,0 + KL II 13,13 + L2 0.18
        Eval losses 2116 = NLL 2102 + KL IC 0,0 + KL II 13,13 + L2 0.18
Batches 7801-7900 in 4.93 sec, Step size: 0.02269
    Training losses 2110 = NLL 2096 + KL IC 0,0 + KL II 14,14 + L2 0.18
        Eval losses 2113 = NLL 2099 + KL IC 0,0 + KL II 14,14 + L2 0.18
Batches 7901-8000 in 4.91 sec, Step size: 0.02247
    Training losses 2109 = NLL 2095 + KL IC 0,0 + KL II 14,14 + L2 0.18
        Eval losses 2106 = NLL 2091 + KL IC 0,0 + KL II 15,15 + L2 0.18
Batches 8001-8100 in 4.94 sec, Step size: 0.02224
    Training losses 2108 = NLL 2093 + KL IC 0,0 + KL II 14,14 + L2 0.18
        Eval losses 2107 = NLL 2092 + KL IC 0,0 + KL II 15,15 + L2 0.18
Batches 8101-8200 in 4.89 sec, Step size: 0.02202
    Training losses 2114 = NLL 2099 + KL IC 0,0 + KL II 15,15 + L2 0.18
        Eval losses 2114 = NLL 2099 + KL IC 0,0 + KL II 15,15 + L2 0.18
Batches 8201-8300 in 4.94 sec, Step size: 0.02180
    Training losses 2107 = NLL 2093 + KL IC 0,0 + KL II 14,14 + L2 0.18
        Eval losses 2108 = NLL 2093 + KL IC 0,0 + KL II 14,14 + L2 0.18
Batches 8301-8400 in 4.97 sec, Step size: 0.02158
    Training losses 2114 = NLL 2098 + KL IC 0,0 + KL II 15,15 + L2 0.18
        Eval losses 2108 = NLL 2093 + KL IC 0,0 + KL II 15,15 + L2 0.18
Batches 8401-8500 in 5.12 sec, Step size: 0.02137
    Training losses 2114 = NLL 2100 + KL IC 0,0 + KL II 14,14 + L2 0.18
        Eval losses 2114 = NLL 2101 + KL IC 0,0 + KL II 13,13 + L2 0.18
Batches 8501-8600 in 4.97 sec, Step size: 0.02116
    Training losses 2117 = NLL 2101 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2116 = NLL 2101 + KL IC 0,0 + KL II 15,15 + L2 0.18
Batches 8601-8700 in 4.96 sec, Step size: 0.02095
    Training losses 2112 = NLL 2098 + KL IC 0,0 + KL II 14,14 + L2 0.18
        Eval losses 2112 = NLL 2098 + KL IC 0,0 + KL II 14,14 + L2 0.18
Batches 8701-8800 in 5.00 sec, Step size: 0.02074
    Training losses 2109 = NLL 2092 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2122 = NLL 2104 + KL IC 0,0 + KL II 18,18 + L2 0.18
Batches 8801-8900 in 4.98 sec, Step size: 0.02053
    Training losses 2113 = NLL 2097 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2111 = NLL 2096 + KL IC 0,0 + KL II 15,15 + L2 0.18
Batches 8901-9000 in 4.94 sec, Step size: 0.02033
    Training losses 2113 = NLL 2097 + KL IC 0,0 + KL II 15,15 + L2 0.18
        Eval losses 2119 = NLL 2103 + KL IC 0,0 + KL II 15,15 + L2 0.18
Batches 9001-9100 in 4.96 sec, Step size: 0.02013
    Training losses 2114 = NLL 2098 + KL IC 0,0 + KL II 15,15 + L2 0.18
        Eval losses 2118 = NLL 2101 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 9101-9200 in 4.96 sec, Step size: 0.01993
    Training losses 2116 = NLL 2100 + KL IC 1,1 + KL II 15,15 + L2 0.18
        Eval losses 2112 = NLL 2095 + KL IC 1,1 + KL II 16,16 + L2 0.18
Batches 9201-9300 in 4.97 sec, Step size: 0.01973
    Training losses 2108 = NLL 2092 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2116 = NLL 2101 + KL IC 0,0 + KL II 15,15 + L2 0.18
Batches 9301-9400 in 4.96 sec, Step size: 0.01953
    Training losses 2118 = NLL 2103 + KL IC 0,0 + KL II 15,15 + L2 0.18
        Eval losses 2112 = NLL 2096 + KL IC 0,0 + KL II 15,15 + L2 0.18
Batches 9401-9500 in 4.94 sec, Step size: 0.01934
    Training losses 2112 = NLL 2094 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2112 = NLL 2095 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 9501-9600 in 4.98 sec, Step size: 0.01914
    Training losses 2117 = NLL 2102 + KL IC 0,0 + KL II 15,15 + L2 0.18
        Eval losses 2112 = NLL 2096 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 9601-9700 in 4.96 sec, Step size: 0.01895
    Training losses 2112 = NLL 2095 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2113 = NLL 2096 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 9701-9800 in 4.96 sec, Step size: 0.01876
    Training losses 2113 = NLL 2098 + KL IC 0,0 + KL II 14,14 + L2 0.18
        Eval losses 2108 = NLL 2094 + KL IC 0,0 + KL II 14,14 + L2 0.18
Batches 9801-9900 in 4.96 sec, Step size: 0.01858
    Training losses 2107 = NLL 2093 + KL IC 0,0 + KL II 13,13 + L2 0.18
        Eval losses 2112 = NLL 2099 + KL IC 0,0 + KL II 13,13 + L2 0.18
Batches 9901-10000 in 5.01 sec, Step size: 0.01839
    Training losses 2109 = NLL 2096 + KL IC 0,0 + KL II 13,13 + L2 0.18
        Eval losses 2113 = NLL 2100 + KL IC 0,0 + KL II 13,13 + L2 0.18
Batches 10001-10100 in 4.95 sec, Step size: 0.01821
    Training losses 2112 = NLL 2097 + KL IC 0,0 + KL II 15,15 + L2 0.18
        Eval losses 2108 = NLL 2093 + KL IC 0,0 + KL II 15,15 + L2 0.18
Batches 10101-10200 in 4.97 sec, Step size: 0.01803
    Training losses 2108 = NLL 2093 + KL IC 0,0 + KL II 14,14 + L2 0.18
        Eval losses 2113 = NLL 2098 + KL IC 0,0 + KL II 14,14 + L2 0.18
Batches 10201-10300 in 4.94 sec, Step size: 0.01785
    Training losses 2113 = NLL 2099 + KL IC 0,0 + KL II 14,14 + L2 0.18
        Eval losses 2112 = NLL 2097 + KL IC 0,0 + KL II 14,14 + L2 0.18
Batches 10301-10400 in 4.93 sec, Step size: 0.01767
    Training losses 2114 = NLL 2099 + KL IC 0,0 + KL II 14,14 + L2 0.18
        Eval losses 2114 = NLL 2099 + KL IC 0,0 + KL II 14,14 + L2 0.18
Batches 10401-10500 in 4.98 sec, Step size: 0.01750
    Training losses 2116 = NLL 2099 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2116 = NLL 2099 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 10501-10600 in 4.92 sec, Step size: 0.01732
    Training losses 2109 = NLL 2092 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2110 = NLL 2094 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 10601-10700 in 4.95 sec, Step size: 0.01715
    Training losses 2115 = NLL 2096 + KL IC 0,0 + KL II 18,18 + L2 0.18
        Eval losses 2110 = NLL 2092 + KL IC 0,0 + KL II 18,18 + L2 0.18
Batches 10701-10800 in 4.96 sec, Step size: 0.01698
    Training losses 2110 = NLL 2094 + KL IC 0,0 + KL II 15,15 + L2 0.18
        Eval losses 2117 = NLL 2102 + KL IC 0,0 + KL II 15,15 + L2 0.18
Batches 10801-10900 in 4.97 sec, Step size: 0.01681
    Training losses 2112 = NLL 2094 + KL IC 0,0 + KL II 18,18 + L2 0.18
        Eval losses 2114 = NLL 2096 + KL IC 0,0 + KL II 18,18 + L2 0.18
Batches 10901-11000 in 4.98 sec, Step size: 0.01664
    Training losses 2116 = NLL 2099 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2112 = NLL 2096 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 11001-11100 in 4.99 sec, Step size: 0.01648
    Training losses 2112 = NLL 2097 + KL IC 0,0 + KL II 15,15 + L2 0.18
        Eval losses 2107 = NLL 2091 + KL IC 0,0 + KL II 15,15 + L2 0.18
Batches 11101-11200 in 4.99 sec, Step size: 0.01631
    Training losses 2119 = NLL 2100 + KL IC 0,0 + KL II 19,19 + L2 0.18
        Eval losses 2118 = NLL 2099 + KL IC 0,0 + KL II 18,18 + L2 0.18
Batches 11201-11300 in 4.99 sec, Step size: 0.01615
    Training losses 2112 = NLL 2095 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2114 = NLL 2098 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 11301-11400 in 4.99 sec, Step size: 0.01599
    Training losses 2106 = NLL 2089 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2109 = NLL 2092 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 11401-11500 in 4.96 sec, Step size: 0.01583
    Training losses 2114 = NLL 2093 + KL IC 0,0 + KL II 21,21 + L2 0.18
        Eval losses 2121 = NLL 2101 + KL IC 0,0 + KL II 20,20 + L2 0.18
Batches 11501-11600 in 4.95 sec, Step size: 0.01567
    Training losses 2114 = NLL 2098 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2117 = NLL 2100 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 11601-11700 in 4.95 sec, Step size: 0.01552
    Training losses 2110 = NLL 2092 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2114 = NLL 2096 + KL IC 0,0 + KL II 18,18 + L2 0.18
Batches 11701-11800 in 4.94 sec, Step size: 0.01536
    Training losses 2114 = NLL 2097 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2109 = NLL 2092 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 11801-11900 in 4.96 sec, Step size: 0.01521
    Training losses 2112 = NLL 2096 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2112 = NLL 2096 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 11901-12000 in 4.95 sec, Step size: 0.01506
    Training losses 2110 = NLL 2093 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2115 = NLL 2098 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 12001-12100 in 4.99 sec, Step size: 0.01491
    Training losses 2110 = NLL 2094 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2109 = NLL 2092 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 12101-12200 in 4.94 sec, Step size: 0.01476
    Training losses 2115 = NLL 2098 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2107 = NLL 2089 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 12201-12300 in 4.92 sec, Step size: 0.01461
    Training losses 2115 = NLL 2097 + KL IC 0,0 + KL II 18,18 + L2 0.18
        Eval losses 2117 = NLL 2099 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 12301-12400 in 4.94 sec, Step size: 0.01447
    Training losses 2114 = NLL 2098 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2112 = NLL 2095 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 12401-12500 in 4.92 sec, Step size: 0.01432
    Training losses 2117 = NLL 2100 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2113 = NLL 2096 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 12501-12600 in 4.97 sec, Step size: 0.01418
    Training losses 2113 = NLL 2095 + KL IC 0,0 + KL II 18,18 + L2 0.18
        Eval losses 2110 = NLL 2092 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 12601-12700 in 4.99 sec, Step size: 0.01404
    Training losses 2111 = NLL 2094 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2109 = NLL 2092 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 12701-12800 in 4.93 sec, Step size: 0.01390
    Training losses 2114 = NLL 2098 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2116 = NLL 2100 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 12801-12900 in 4.92 sec, Step size: 0.01376
    Training losses 2116 = NLL 2099 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2109 = NLL 2093 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 12901-13000 in 4.95 sec, Step size: 0.01363
    Training losses 2115 = NLL 2099 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2109 = NLL 2093 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 13001-13100 in 4.94 sec, Step size: 0.01349
    Training losses 2113 = NLL 2096 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2114 = NLL 2097 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 13101-13200 in 4.94 sec, Step size: 0.01336
    Training losses 2117 = NLL 2101 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2111 = NLL 2094 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 13201-13300 in 4.96 sec, Step size: 0.01322
    Training losses 2105 = NLL 2087 + KL IC 0,0 + KL II 18,18 + L2 0.18
        Eval losses 2107 = NLL 2090 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 13301-13400 in 5.03 sec, Step size: 0.01309
    Training losses 2110 = NLL 2092 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2110 = NLL 2092 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 13401-13500 in 4.96 sec, Step size: 0.01296
    Training losses 2110 = NLL 2093 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2112 = NLL 2095 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 13501-13600 in 4.96 sec, Step size: 0.01283
    Training losses 2112 = NLL 2095 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2115 = NLL 2098 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 13601-13700 in 4.95 sec, Step size: 0.01270
    Training losses 2112 = NLL 2096 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2119 = NLL 2102 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 13701-13800 in 4.93 sec, Step size: 0.01258
    Training losses 2116 = NLL 2096 + KL IC 0,0 + KL II 19,19 + L2 0.18
        Eval losses 2109 = NLL 2090 + KL IC 0,0 + KL II 18,18 + L2 0.18
Batches 13801-13900 in 4.94 sec, Step size: 0.01245
    Training losses 2113 = NLL 2096 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2114 = NLL 2097 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 13901-14000 in 4.95 sec, Step size: 0.01233
    Training losses 2116 = NLL 2097 + KL IC 0,0 + KL II 18,18 + L2 0.18
        Eval losses 2114 = NLL 2096 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 14001-14100 in 4.92 sec, Step size: 0.01221
    Training losses 2108 = NLL 2091 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2121 = NLL 2104 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 14101-14200 in 4.94 sec, Step size: 0.01208
    Training losses 2116 = NLL 2097 + KL IC 0,0 + KL II 18,18 + L2 0.18
        Eval losses 2114 = NLL 2096 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 14201-14300 in 4.98 sec, Step size: 0.01196
    Training losses 2109 = NLL 2092 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2115 = NLL 2098 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 14301-14400 in 4.93 sec, Step size: 0.01185
    Training losses 2112 = NLL 2095 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2112 = NLL 2095 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 14401-14500 in 5.05 sec, Step size: 0.01173
    Training losses 2112 = NLL 2095 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2109 = NLL 2092 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 14501-14600 in 4.94 sec, Step size: 0.01161
    Training losses 2111 = NLL 2094 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2111 = NLL 2093 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 14601-14700 in 4.92 sec, Step size: 0.01150
    Training losses 2117 = NLL 2100 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2113 = NLL 2096 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 14701-14800 in 4.92 sec, Step size: 0.01138
    Training losses 2112 = NLL 2094 + KL IC 0,0 + KL II 18,18 + L2 0.18
        Eval losses 2112 = NLL 2094 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 14801-14900 in 4.95 sec, Step size: 0.01127
    Training losses 2112 = NLL 2096 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2107 = NLL 2090 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 14901-15000 in 4.91 sec, Step size: 0.01116
    Training losses 2112 = NLL 2095 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2117 = NLL 2100 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 15001-15100 in 4.91 sec, Step size: 0.01104
    Training losses 2110 = NLL 2093 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2116 = NLL 2099 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 15101-15200 in 4.95 sec, Step size: 0.01093
    Training losses 2111 = NLL 2095 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2112 = NLL 2094 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 15201-15300 in 4.98 sec, Step size: 0.01083
    Training losses 2107 = NLL 2090 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2112 = NLL 2095 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 15301-15400 in 4.94 sec, Step size: 0.01072
    Training losses 2111 = NLL 2093 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2106 = NLL 2089 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 15401-15500 in 4.93 sec, Step size: 0.01061
    Training losses 2111 = NLL 2094 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2114 = NLL 2097 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 15501-15600 in 4.95 sec, Step size: 0.01051
    Training losses 2113 = NLL 2097 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2110 = NLL 2095 + KL IC 0,0 + KL II 15,15 + L2 0.18
Batches 15601-15700 in 4.96 sec, Step size: 0.01040
    Training losses 2112 = NLL 2094 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2108 = NLL 2091 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 15701-15800 in 4.95 sec, Step size: 0.01030
    Training losses 2110 = NLL 2092 + KL IC 0,0 + KL II 18,18 + L2 0.18
        Eval losses 2116 = NLL 2098 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 15801-15900 in 4.94 sec, Step size: 0.01020
    Training losses 2118 = NLL 2100 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2112 = NLL 2095 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 15901-16000 in 4.93 sec, Step size: 0.01009
    Training losses 2109 = NLL 2093 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2114 = NLL 2097 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 16001-16100 in 4.98 sec, Step size: 0.00999
    Training losses 2108 = NLL 2091 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2112 = NLL 2095 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 16101-16200 in 4.93 sec, Step size: 0.00989
    Training losses 2111 = NLL 2095 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2110 = NLL 2094 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 16201-16300 in 4.93 sec, Step size: 0.00980
    Training losses 2109 = NLL 2092 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2112 = NLL 2095 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 16301-16400 in 4.90 sec, Step size: 0.00970
    Training losses 2120 = NLL 2103 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2122 = NLL 2104 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 16401-16500 in 4.93 sec, Step size: 0.00960
    Training losses 2110 = NLL 2093 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2107 = NLL 2090 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 16501-16600 in 4.95 sec, Step size: 0.00951
    Training losses 2110 = NLL 2093 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2112 = NLL 2095 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 16601-16700 in 4.94 sec, Step size: 0.00941
    Training losses 2109 = NLL 2092 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2108 = NLL 2091 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 16701-16800 in 5.01 sec, Step size: 0.00932
    Training losses 2108 = NLL 2091 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2112 = NLL 2095 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 16801-16900 in 4.94 sec, Step size: 0.00923
    Training losses 2110 = NLL 2093 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2109 = NLL 2092 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 16901-17000 in 4.96 sec, Step size: 0.00913
    Training losses 2113 = NLL 2096 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2110 = NLL 2093 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 17001-17100 in 4.91 sec, Step size: 0.00904
    Training losses 2112 = NLL 2094 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2110 = NLL 2093 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 17101-17200 in 4.96 sec, Step size: 0.00895
    Training losses 2110 = NLL 2094 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2108 = NLL 2091 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 17201-17300 in 4.97 sec, Step size: 0.00886
    Training losses 2112 = NLL 2094 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2105 = NLL 2087 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 17301-17400 in 4.97 sec, Step size: 0.00878
    Training losses 2109 = NLL 2092 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2108 = NLL 2092 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 17401-17500 in 4.98 sec, Step size: 0.00869
    Training losses 2108 = NLL 2091 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2108 = NLL 2090 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 17501-17600 in 4.94 sec, Step size: 0.00860
    Training losses 2114 = NLL 2096 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2112 = NLL 2094 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 17601-17700 in 4.92 sec, Step size: 0.00852
    Training losses 2109 = NLL 2092 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2107 = NLL 2090 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 17701-17800 in 4.97 sec, Step size: 0.00843
    Training losses 2110 = NLL 2093 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2115 = NLL 2098 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 17801-17900 in 4.97 sec, Step size: 0.00835
    Training losses 2111 = NLL 2095 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2110 = NLL 2093 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 17901-18000 in 4.95 sec, Step size: 0.00826
    Training losses 2112 = NLL 2095 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2111 = NLL 2094 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 18001-18100 in 4.92 sec, Step size: 0.00818
    Training losses 2111 = NLL 2095 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2115 = NLL 2098 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 18101-18200 in 4.94 sec, Step size: 0.00810
    Training losses 2114 = NLL 2097 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2114 = NLL 2096 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 18201-18300 in 4.92 sec, Step size: 0.00802
    Training losses 2111 = NLL 2094 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2111 = NLL 2095 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 18301-18400 in 4.95 sec, Step size: 0.00794
    Training losses 2111 = NLL 2095 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2114 = NLL 2097 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 18401-18500 in 4.91 sec, Step size: 0.00786
    Training losses 2112 = NLL 2095 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2113 = NLL 2096 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 18501-18600 in 4.96 sec, Step size: 0.00778
    Training losses 2115 = NLL 2097 + KL IC 0,0 + KL II 18,18 + L2 0.18
        Eval losses 2106 = NLL 2090 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 18601-18700 in 4.94 sec, Step size: 0.00771
    Training losses 2113 = NLL 2096 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2110 = NLL 2093 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 18701-18800 in 4.97 sec, Step size: 0.00763
    Training losses 2110 = NLL 2092 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2110 = NLL 2092 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 18801-18900 in 4.95 sec, Step size: 0.00755
    Training losses 2114 = NLL 2098 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2114 = NLL 2097 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 18901-19000 in 4.96 sec, Step size: 0.00748
    Training losses 2114 = NLL 2097 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2107 = NLL 2089 + KL IC 0,0 + KL II 18,18 + L2 0.18
Batches 19001-19100 in 4.94 sec, Step size: 0.00740
    Training losses 2109 = NLL 2091 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2115 = NLL 2097 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 19101-19200 in 4.93 sec, Step size: 0.00733
    Training losses 2108 = NLL 2091 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2112 = NLL 2096 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 19201-19300 in 5.01 sec, Step size: 0.00726
    Training losses 2105 = NLL 2088 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2112 = NLL 2095 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 19301-19400 in 4.97 sec, Step size: 0.00718
    Training losses 2110 = NLL 2094 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2115 = NLL 2099 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 19401-19500 in 4.91 sec, Step size: 0.00711
    Training losses 2110 = NLL 2094 + KL IC 0,0 + KL II 15,15 + L2 0.18
        Eval losses 2111 = NLL 2095 + KL IC 0,0 + KL II 15,15 + L2 0.18
Batches 19501-19600 in 4.98 sec, Step size: 0.00704
    Training losses 2114 = NLL 2096 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2112 = NLL 2095 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 19601-19700 in 4.94 sec, Step size: 0.00697
    Training losses 2109 = NLL 2091 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2110 = NLL 2093 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 19701-19800 in 4.93 sec, Step size: 0.00690
    Training losses 2103 = NLL 2086 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2107 = NLL 2091 + KL IC 0,0 + KL II 16,16 + L2 0.18
Batches 19801-19900 in 4.94 sec, Step size: 0.00683
    Training losses 2110 = NLL 2093 + KL IC 0,0 + KL II 17,17 + L2 0.18
        Eval losses 2110 = NLL 2092 + KL IC 0,0 + KL II 17,17 + L2 0.18
Batches 19901-20000 in 4.95 sec, Step size: 0.00677
    Training losses 2109 = NLL 2092 + KL IC 0,0 + KL II 16,16 + L2 0.18
        Eval losses 2105 = NLL 2088 + KL IC 0,0 + KL II 16,16 + L2 0.18

In [28]:
# Plot the training details
x = onp.arange(0, num_batches, print_every)
plt.figure(figsize=(20,6))
plt.subplot(251)
plt.plot(x, opt_details['tlosses']['total'], 'k')
plt.ylabel('Training')
plt.title('Total loss')
plt.subplot(252)
plt.plot(x, opt_details['tlosses']['nlog_p_xgz'], 'b')
plt.title('Negative log p(x|z)')
plt.subplot(253)
plt.plot(x, opt_details['tlosses']['kl_ii'], 'r')
plt.title('KL inferred inputs')
plt.subplot(254)
plt.plot(x, opt_details['tlosses']['kl_g0'], 'g')
plt.title('KL initial state')
plt.subplot(255)
plt.plot(x, opt_details['tlosses']['l2'], 'c')
plt.xlabel('Training batch')
plt.title('L2 loss')
plt.subplot(256)
plt.plot(x, opt_details['elosses']['total'], 'k')
plt.xlabel('Training batch')
plt.ylabel('Evaluation')
plt.subplot(257)
plt.plot(x, opt_details['elosses']['nlog_p_xgz'], 'b')
plt.xlabel('Training batch')
plt.subplot(258)
plt.plot(x, opt_details['elosses']['kl_ii'], 'r')
plt.xlabel('Training batch')
plt.subplot(259)
plt.plot(x, opt_details['elosses']['kl_g0'], 'g')
plt.xlabel('Training batch');



In [29]:
# See the effect of the KL warmup: these are the KL penalties
# before the warmup scaling is applied.
plt.figure(figsize=(7,4))
plt.subplot(221)
plt.plot(x, opt_details['tlosses']['kl_ii_prescale'], 'r--')
plt.ylabel('Training')
plt.subplot(222)
plt.plot(x, opt_details['tlosses']['kl_g0_prescale'], 'g--')
plt.subplot(223)
plt.plot(x, opt_details['elosses']['kl_ii_prescale'], 'r--')
plt.ylabel('Evaluation')
plt.xlabel('Training batch')
plt.subplot(224)
plt.plot(x, opt_details['elosses']['kl_g0_prescale'], 'g--')
plt.xlabel('Training batch');
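The dashed "prescale" curves are the raw KL penalties, while the KL losses plotted in the previous cell include a warmup factor that ramps the penalty up early in training. The sketch below shows one common way such a factor is built, assuming a linear ramp. The functions `kl_warmup_scale` and `scaled_loss` and the parameters `warmup_start`/`warmup_end` are hypothetical and are not taken from the tutorial's `lfads` module, whose exact schedule may differ.

```python
import numpy as np

def kl_warmup_scale(step, warmup_start, warmup_end):
    """Hypothetical linear KL warmup: 0 before warmup_start, 1 after warmup_end."""
    if warmup_end <= warmup_start:
        return 1.0
    frac = (step - warmup_start) / (warmup_end - warmup_start)
    return float(np.clip(frac, 0.0, 1.0))

def scaled_loss(nll, kl_ii_prescale, kl_g0_prescale, step,
                warmup_start=0, warmup_end=1000):
    """Total loss with both KL terms multiplied by the (assumed) warmup factor."""
    s = kl_warmup_scale(step, warmup_start, warmup_end)
    return nll + s * kl_ii_prescale + s * kl_g0_prescale

# How the factor ramps over training steps.
for step in [0, 250, 500, 1000, 5000]:
    print(step, kl_warmup_scale(step, 0, 1000))
```

Early on the model is free to use its latent codes without paying the full KL price; once the factor reaches 1 the scaled and prescale curves coincide.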


Save the LFADS model parameters


In [30]:
fname_uniquifier = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
network_fname = ('trained_params_' + rnn_type + '_' + task_type + '_' +
                 fname_uniquifier + '.npz')
network_path = os.path.join(output_dir, network_fname)

# Note we are just using numpy save instead of h5 because the LFADS parameters
# are nested dictionaries, something I couldn't get h5 to save easily.
print("Saving parameters: ", network_path)
onp.savez(network_path, trained_params)


Saving parameters:  /tmp/lfads/output/trained_params_lfads_integrator_2019-06-19_23:12:26.npz

In [31]:
# After training, you can load these up, after locating the save file.
if False:
    loaded_params = onp.load(network_path, allow_pickle=True)
    trained_params = loaded_params['arr_0'].item()
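Why `allow_pickle=True` and `.item()` are needed: `numpy.savez` stores a Python dictionary passed positionally as a zero-dimensional object array (pickled) under the default key `arr_0`. Loading it back therefore requires allowing pickles and unwrapping the 0-d array with `.item()`. A small self-contained round trip, with a toy nested dictionary standing in for `trained_params`:

```python
import os
import tempfile
import numpy as np

# Toy nested parameter dictionary standing in for trained_params (made-up values).
params = {'gen': {'w': np.ones((2, 3)), 'b': np.zeros(3)},
          'ic_enc': {'w': np.eye(2)}}

path = os.path.join(tempfile.mkdtemp(), 'toy_params.npz')
# The dict becomes a 0-d object array saved under the default name 'arr_0'.
np.savez(path, params)

with np.load(path, allow_pickle=True) as loaded:
    # .item() unwraps the 0-d object array back into the original dict.
    restored = loaded['arr_0'].item()

print(restored['gen']['w'].shape)  # (2, 3)
```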

LFADS Visualization

To plot the results of LFADS, namely inferred quantities such as the inferred inputs, factors, or rates, we have to do a sample-and-average operation. Remember, the latent variables for LFADS are the initial state and the inferred inputs, and they are stochastic codes: even for a single trial, LFADS infers a distribution over them rather than a single value. To get good inference for a given trial, we sample a large number of times from these per-trial latent distributions, run the generator forward for each sample, and then average all the quantities of interest over the samples.

If LFADS were a linear model, it would be equivalent, and much more efficient, to decode the posterior means: that is, take the mean of the initial state distribution and the mean of the inferred input distribution, and run the decoder once. (This, btw, is a great exercise for the tutorial reader: implement posterior-mean decoding in this tutorial.)
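A quick numerical check of why linearity matters, using a toy Gaussian posterior and two hypothetical decoders rather than LFADS itself: a linear readout, and an exponential readout that mimics the log-rate to rate mapping. Averaging decoded samples matches decoding the posterior mean only in the linear case. For the exponential decoder, Jensen's inequality makes the sample average systematically larger.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy diagonal Gaussian posterior over a 3-d latent (made-up values).
post_mean = np.array([0.5, -1.0, 0.2])
post_std = np.array([0.3, 0.5, 0.4])

# Toy readout from the latent to 4 "neurons" (made-up weights).
W = np.array([[1.0, 0.5, -0.5],
              [0.2, -1.0, 0.3],
              [-0.7, 0.4, 1.0],
              [0.5, 0.5, 0.5]])

def linear_decoder(z):
    return z @ W.T

def exp_decoder(z):
    # Exponential nonlinearity, like mapping log-rates to rates.
    return np.exp(z @ W.T)

# Many posterior samples via z = mean + std * eps.
eps = rng.standard_normal((200_000, 3))
samples = post_mean + post_std * eps

gaps = {}
for name, decoder in [('linear', linear_decoder), ('exp', exp_decoder)]:
    sample_avg = decoder(samples).mean(axis=0)  # average of decoded samples
    mean_decode = decoder(post_mean)            # decode the posterior mean once
    gaps[name] = np.max(np.abs(sample_avg - mean_decode))

print(gaps)  # linear gap is Monte Carlo noise; exp gap is a systematic bias
```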

Here we use batching: we take the 'posterior average' using as many samples from the latent variable distributions as there are trials in a batch.
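A sketch of batched sample-and-average under stated assumptions: a diagonal Gaussian posterior given by `mean` and `logvar`, and a toy tanh RNN standing in for the LFADS generator. For brevity only the initial state is sampled here, whereas LFADS also samples the inferred inputs. `posterior_sample_and_average` and `toy_generator` are hypothetical helpers, not the signature or internals of `lfads.posterior_sample_and_average_jit`.

```python
import numpy as np

def posterior_sample_and_average(decode_fn, mean, logvar, batch_size, rng):
    """Average decode_fn over batch_size samples from N(mean, exp(logvar)).

    Hypothetical stand-in: all samples are drawn and decoded as one batch,
    then averaged over the batch axis.
    """
    std = np.exp(0.5 * logvar)
    eps = rng.standard_normal((batch_size,) + mean.shape)
    z = mean + std * eps              # (batch_size, latent_dim) posterior samples
    return decode_fn(z).mean(axis=0)  # average the decoded outputs over samples

def toy_generator(z, num_steps=20):
    """Toy tanh RNN started from initial state z; returns rates (batch, T, 2)."""
    latent_dim = z.shape[-1]
    A = 0.9 * np.eye(latent_dim)               # made-up recurrent weights
    C = np.ones((2, latent_dim)) / latent_dim  # made-up readout weights
    h = z
    lograte = []
    for _ in range(num_steps):
        h = np.tanh(h @ A.T)
        lograte.append(h @ C.T)
    return np.exp(np.stack(lograte, axis=1))

rng = np.random.default_rng(0)
mean = np.array([0.5, -0.2, 1.0])
logvar = np.log(0.1) * np.ones(3)
rates = posterior_sample_and_average(toy_generator, mean, logvar,
                                     batch_size=128, rng=rng)
print(rates.shape)  # (20, 2)
```

Drawing all samples at once lets a single vectorized (or, in JAX, `vmap`-ed and jitted) forward pass replace a Python loop over samples.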

So the main result of this tutorial, the moment you've been waiting for, is the comparison between the true rates of the integrator RNN and the rates inferred by LFADS, and between the true input to the integrator RNN and the inputs inferred by LFADS. You can see how well we did by generating lots of trials here.


In [47]:
# Plot a bunch of examples of eval trials run through LFADS.
from importlib import reload  # In Python 3, reload is no longer a builtin.
reload(plotting)
#reload(lfads)

def plot_rescale_fun(a): 
    fac = max_firing_rate * data_dt
    return renormed_fun(a) * fac


# Index of our example trial within the evaluation set, so that psa_dict
# below is for the same trial as our example signal.
bidx = my_example_bidx - eval_data_offset

nexamples_to_save = 1
for eidx in range(nexamples_to_save):
    fkey = random.fold_in(key, eidx)
    psa_example = eval_data[bidx,:,:].astype(np.float32)
    psa_dict = lfads.posterior_sample_and_average_jit(trained_params, lfads_hps, fkey, psa_example)

    # The inferred input and true input are rescaled and shifted via
    # linear regression to match, because the inferred input is only
    # identifiable up to a scale and offset.
    plotting.plot_lfads(psa_example, psa_dict,
                        data_dict, eval_data_offset+bidx, plot_rescale_fun)


bidx:  18432

Coming back to our example signal, how well does LFADS do on it compared to the other, much easier to implement, methods? It gives a noticeable improvement in inferring the underlying rate.
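For reference, here is roughly what the two simpler baselines do, as a numpy sketch. Temporal filtering convolves each unit's spike train with a Gaussian kernel. Spatial (PCA) filtering projects a time-by-neuron spike matrix onto its top principal components and back. The kernel width `sigma_bins` and `n_components` are free parameters here. This sketch fits PCA on whatever matrix it is given, which may differ from how the tutorial's `pca` object and `my_filtered_spikes` were produced earlier.

```python
import numpy as np

def gaussian_smooth(spikes_txn, sigma_bins):
    """Temporal filtering: convolve each column with a normalized Gaussian kernel.

    Assumes the number of time bins is at least the kernel length, so that
    mode='same' returns an output with the input's length.
    """
    half = int(np.ceil(3 * sigma_bins))
    t = np.arange(-half, half + 1)
    kernel = np.exp(-0.5 * (t / sigma_bins) ** 2)
    kernel /= kernel.sum()
    return np.stack([np.convolve(spikes_txn[:, n], kernel, mode='same')
                     for n in range(spikes_txn.shape[1])], axis=1)

def pca_smooth(spikes_txn, n_components):
    """Spatial filtering: project onto the top principal components and back."""
    mu = spikes_txn.mean(axis=0)
    centered = spikes_txn - mu
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    V = vt[:n_components].T  # (n, k) top-k principal directions
    return centered @ V @ V.T + mu

# Usage on toy Poisson spikes (100 time bins, 5 units).
spikes = np.random.default_rng(0).poisson(0.5, size=(100, 5)).astype(float)
print(gaussian_smooth(spikes, 2.0).shape, pca_smooth(spikes, 2).shape)
```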


In [46]:
plt.figure(figsize=(16,4))

plt.subplot(141)
plt.plot(my_signal, 'r');
plt.stem(my_signal_spikified);
_, _, r_spike, _, _ = scipy.stats.linregress(my_signal_spikified, my_signal)
plt.title('Raw spikes R^2=%.3f' % (r_spike**2))
plt.legend(('True rate', 'Spikes'));


plt.subplot(142)
plt.plot(my_signal, 'r');
plt.plot(my_filtered_spikes);
_, _, c_tfilt, _, _ = scipy.stats.linregress(my_filtered_spikes, my_signal)
plt.title("Temporal filtering R^2=%.3f" % (c_tfilt**2));
plt.legend(('True rate', 'Filtered spikes'));

plt.subplot(143)
plt.plot(my_signal, 'r')
plt.plot(my_example_ipca[:,my_example_hidx])
_, _, c_sfilt, _, _ = scipy.stats.linregress(my_example_ipca[:,my_example_hidx], my_signal)
plt.legend(('True rate', 'PCA smoothed spikes'))
plt.title('Spatial filtering R^2=%.3f' % (c_sfilt**2));

plt.subplot(144)
plt.plot(my_signal, 'r')
my_lfads_rate = onp.exp(psa_dict['lograte_t'][:,my_example_hidx])
plt.plot(my_lfads_rate)
_, _, r_lfads, _, _ = scipy.stats.linregress(my_lfads_rate, my_signal)
plt.legend(('True rate', 'LFADS rate'))
plt.title('LFADS "filtering" R^2=%.3f' % (r_lfads**2));


That single example can't tell the whole story, so let us look at the average. Across a large set of trials, LFADS is much better than spatial filtering.

Take an average over all the hidden units in 1000 evaluation trials.
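The next cell calls `scipy.stats.linregress` once per unit and squares its correlation coefficient (`rvalue`). For simple linear regression that squared correlation is the $R^2$, so the same numbers can be computed for all units of a trial at once. A vectorized sketch; `per_unit_r2` is a hypothetical helper, and it returns `nan` for a unit whose estimate or truth is constant.

```python
import numpy as np

def per_unit_r2(estimate_txn, truth_txn):
    """Squared Pearson correlation between estimate and truth, per unit (column)."""
    e = estimate_txn - estimate_txn.mean(axis=0)
    g = truth_txn - truth_txn.mean(axis=0)
    r = (e * g).sum(axis=0) / np.sqrt((e ** 2).sum(axis=0) * (g ** 2).sum(axis=0))
    return r ** 2

# Usage on toy data: a noisy, affinely rescaled copy of the truth.
rng = np.random.default_rng(0)
truth = rng.standard_normal((100, 4))
estimate = 3.0 * truth - 1.0 + 0.5 * rng.standard_normal((100, 4))
print(per_unit_r2(estimate, truth))  # one R^2 per unit
```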


In [34]:
nexamples = 1000
r2_sfilts = onp.zeros(nexamples*data_dim)
r2_lfadss = onp.zeros(nexamples*data_dim)
eidx = 0
for bidx in range(nexamples):
    ebidx = eval_data_offset + bidx
    
    # Get the LFADS decode for this trial.
    fkey = random.fold_in(key, bidx)
    psa_example = eval_data[bidx,:,:].astype(np.float32)
    psa_dict = lfads.posterior_sample_and_average_jit(trained_params, lfads_hps, fkey, psa_example)
    
    # Get the spatially smoothed trial.
    trial_rates = scale*renormed_data[ebidx, :, :]
    trial_spikes = data_bxtxn[ebidx, :, :]
    spikes_pca = pca.transform(trial_spikes)
    spikes_ipca = pca.inverse_transform(spikes_pca)
    
    for hidx in range(data_dim):
        sig = trial_rates[:, hidx]
        ipca_rate = spikes_ipca[:,hidx]
        lfads_rate = onp.exp(psa_dict['lograte_t'][:,hidx])
        _, _, cc_sfilt, _, _ = scipy.stats.linregress(ipca_rate, sig)
        _, _, cc_lfads, _, _ = scipy.stats.linregress(lfads_rate, sig)

        r2_sfilts[eidx] = cc_sfilt**2
        r2_lfadss[eidx] = cc_lfads**2
        eidx += 1
    
plt.figure(figsize=(8,4))
plt.subplot(121)
plt.hist(r2_sfilts, 50)
plt.title('Spatial filtering, hist of R^2, <%.3f>' % (onp.mean(r2_sfilts)))
plt.xlim([-.5, 1.0])

plt.subplot(122)
plt.hist(r2_lfadss, 50);
plt.title('LFADS filtering, hist of R^2, <%.3f>' % (onp.mean(r2_lfadss)));
plt.xlim([-.5, 1.0]);


Compare the inferred inputs learned by LFADS to the actual inputs to the integrator RNN.

Finally, we can look at the average correlation between the inferred inputs and the true inputs to the integrator RNN. The inferred input can be arbitrarily scaled, shifted, or (for multi-dimensional inputs) rotated, so we first fit a linear regression that maps the inferred input onto the true input, and then compute the $R^2$.
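A small numpy check with toy signals, not tutorial data: an affine rescaling of a 1-d regressor leaves $R^2$ unchanged, even when the inferred input comes out sign-flipped. So the regression step puts the inferred input in the true input's units for plotting, but does not alter the score.

```python
import numpy as np

def r2(x, y):
    """Squared Pearson correlation between 1-d arrays x and y."""
    return np.corrcoef(x, y)[0, 1] ** 2

rng = np.random.default_rng(1)
true_input = rng.standard_normal(100)
# Toy "inferred" input: a sign-flipped, scaled, shifted, noisy copy of the truth.
inferred = -3.0 * true_input + 0.5 + 0.3 * rng.standard_normal(100)

# Least-squares fit of true_input ~ slope * inferred + intercept.
slope, intercept = np.polyfit(inferred, true_input, 1)
rescaled = slope * inferred + intercept

print(r2(inferred, true_input), r2(rescaled, true_input))  # equal up to rounding
```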


In [35]:
r2_iis = []
nexamples = 1000
for bidx in range(nexamples):
    ebidx = eval_data_offset + bidx
    
    # Get the LFADS decode for this trial.
    psa_example = eval_data[bidx,:,:].astype(np.float32)
    fkey = random.fold_in(key, bidx)
    psa_dict = lfads.posterior_sample_and_average_jit(trained_params, lfads_hps, fkey, psa_example)
    
    # Get the true input and inferred input
    true_input = onp.squeeze(data_dict['inputs'][ebidx])
    inferred_input = onp.squeeze(psa_dict['ii_t'])
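    # Map the inferred input onto the true input's units. R^2 is invariant to
    # this affine map; the rescaling just makes the comparison explicit.
    slope, intercept, _, _, _ = scipy.stats.linregress(inferred_input, true_input)
    _, _, cc_ii, _, _ = scipy.stats.linregress(slope * inferred_input + intercept, true_input)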
    
    r2_iis.append(cc_ii**2)
    
r2_iis = onp.array(r2_iis)

plt.hist(r2_iis, 20)
plt.title('Correlation between rescaled inferred inputs and true inputs, hist of R^2, <%.3f>' % (onp.mean(r2_iis)))
plt.xlim([0.0, 1.0]);


Compare the inferred initial state for the LFADS generator to the actual initial state of the integrator RNN.

To finish, we can examine the relationship between the initial condition (h0) of the integrator RNN and the inferred initial condition of the LFADS generator. We color each trial by the readout of the integrator RNN's initial state, which is basically the position on the line attractor before any further input arrives. In the integrator RNN example, we made sure to seed these initial states with various values along the line attractor, so we expect a smooth color gradient along a line.
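The t-SNE plots in the cells that follow are qualitative. One simple quantitative complement is to fit a linear readout from the inferred initial-condition means (`ic_means`) to the color values (`colors`) and report its $R^2$. The sketch below uses a hypothetical helper and toy data standing in for those arrays. It reports in-sample $R^2$, which is optimistic (hold out trials for a fair number), and a low score would not by itself rule out a nonlinear encoding.

```python
import numpy as np

def linear_decode_r2(features_nxd, target_n):
    """Fit target ~ features @ w + b by least squares; return the in-sample R^2."""
    X = np.column_stack([features_nxd, np.ones(len(target_n))])
    w, *_ = np.linalg.lstsq(X, target_n, rcond=None)
    pred = X @ w
    ss_res = np.sum((target_n - pred) ** 2)
    ss_tot = np.sum((target_n - target_n.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

# Toy stand-ins: a 1-d position along a line, embedded noisily in 8 dimensions.
rng = np.random.default_rng(2)
position = rng.uniform(-1, 1, size=500)   # stands in for colors
direction = rng.standard_normal(8)        # made-up embedding direction
toy_ic = position[:, None] * direction + 0.05 * rng.standard_normal((500, 8))
print(linear_decode_r2(toy_ic, position))  # close to 1 for this toy line
```

With the tutorial's arrays this would be `linear_decode_r2(ic_means, colors)`.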


In [36]:
ntrials = 1000
true_h0s = onp.zeros([ntrials, data_dim])
ic_means = onp.zeros([ntrials, gen_dim])
colors = onp.zeros(ntrials)
for bidx in range(ntrials):
    ebidx = eval_data_offset + bidx
    
    # Get the LFADS decode for this trial.
    psa_example = eval_data[bidx,:,:].astype(np.float32)
    fkey = random.fold_in(key, bidx)
    #psa_dict = lfads.posterior_sample_and_average_jit(trained_params, lfads_hps, fkey, psa_example)
    lfads_results = lfads.lfads_jit(trained_params, lfads_hps, fkey, psa_example, 1.0)
    # Get the true initial condition (and the readout of the true h0 for color),
    # and the mean of the inferred initial-condition distribution from LFADS.
    true_h0s[bidx,:] = data_dict['h0s'][ebidx]
    colors[bidx] = data_dict['outputs'][ebidx][0]
    ic_means[bidx,:] = lfads_results['ic_mean']

In [37]:
from sklearn.manifold import TSNE
plt.figure(figsize=(16,8))
plt.subplot(121)
h0s_embedded = TSNE(n_components=2).fit_transform(true_h0s)
plt.scatter(h0s_embedded[:,0], h0s_embedded[:,1], c=colors)
plt.title('TSNE visualization of integrator RNN initial state')
plt.subplot(122)
ic_means_embedded = TSNE(n_components=2).fit_transform(ic_means)
plt.scatter(ic_means_embedded[:,0], ic_means_embedded[:,1], c=colors);
plt.title('TSNE visualization of LFADS inferred initial generator state');

